Faster convolve1d in numba backend #1378

Merged (3 commits) on Apr 27, 2025
6 changes: 3 additions & 3 deletions pytensor/link/jax/dispatch/signal/conv.py
@@ -1,11 +1,11 @@
 import jax
 
 from pytensor.link.jax.dispatch import jax_funcify
-from pytensor.tensor.signal.conv import Conv1d
+from pytensor.tensor.signal.conv import Convolve1d
 
 
-@jax_funcify.register(Conv1d)
-def jax_funcify_Conv1d(op, node, **kwargs):
+@jax_funcify.register(Convolve1d)
+def jax_funcify_Convolve1d(op, node, **kwargs):
     mode = op.mode
 
     def conv1d(data, kernel):
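The collapsed body of the JAX wrapper is not shown in this diff; for orientation only, a minimal fixed-mode wrapper around jax.numpy.convolve might look like the sketch below (an assumption, not the PR's actual code):

import jax.numpy as jnp


def make_jax_conv1d(mode: str):
    # Hypothetical helper: close over the Op's mode and defer to jnp.convolve.
    def conv1d(data, kernel):
        return jnp.convolve(data, kernel, mode=mode)

    return conv1d


conv1d_valid = make_jax_conv1d("valid")
print(conv1d_valid(jnp.arange(5.0), jnp.ones(3)))  # [3. 6. 9.]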
68 changes: 61 additions & 7 deletions pytensor/link/numba/dispatch/signal/conv.py
@@ -1,16 +1,70 @@
 import numpy as np
+from numba.np.arraymath import _get_inner_prod
 
 from pytensor.link.numba.dispatch import numba_funcify
 from pytensor.link.numba.dispatch.basic import numba_njit
-from pytensor.tensor.signal.conv import Conv1d
+from pytensor.tensor.signal.conv import Convolve1d
 
 
-@numba_funcify.register(Conv1d)
-def numba_funcify_Conv1d(op, node, **kwargs):
+@numba_funcify.register(Convolve1d)
+def numba_funcify_Convolve1d(op, node, **kwargs):
+    # This specialized version is faster than the overloaded numba np.convolve
     mode = op.mode
+    a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype
+    out_dtype = node.outputs[0].type.dtype
+    innerprod = _get_inner_prod(a_dtype, b_dtype)
 
-    @numba_njit
-    def conv1d(data, kernel):
-        return np.convolve(data, kernel, mode=mode)
+    if mode == "valid":
 
-    return conv1d
+        def valid_convolve1d(x, y):
+            nx = len(x)
+            ny = len(y)
+            if nx < ny:
+                x, y = y, x
+                nx, ny = ny, nx
+            y_flipped = y[::-1]
+
+            length = nx - ny + 1
+            ret = np.empty(length, out_dtype)
+
+            for i in range(length):
+                ret[i] = innerprod(x[i : i + ny], y_flipped)
+
+            return ret
+
+        return numba_njit(valid_convolve1d)
+
+    elif mode == "full":
+
+        def full_convolve1d(x, y):
+            nx = len(x)
+            ny = len(y)
+            if nx < ny:
+                x, y = y, x
+                nx, ny = ny, nx
+            y_flipped = y[::-1]
+
+            length = nx + ny - 1
+            ret = np.empty(length, out_dtype)
+            idx = 0
+
+            for i in range(ny - 1):
+                k = i + 1
+                ret[idx] = innerprod(x[:k], y_flipped[-k:])
+                idx = idx + 1
+
+            for i in range(nx - ny + 1):
+                ret[idx] = innerprod(x[i : i + ny], y_flipped)
+                idx = idx + 1
+
+            for i in range(ny - 1):
+                k = ny - i - 1
+                ret[idx] = innerprod(x[-k:], y_flipped[:k])
+                idx = idx + 1
+
+            return ret
+
+        return numba_njit(full_convolve1d)
+
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
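For readers who want to sanity-check the loop structure above outside of Numba, here is a plain NumPy sketch of the same "full"-mode strategy (illustrative only; the helper name and sizes are made up): a leading loop over growing partial overlaps, a middle loop of full overlaps computed as inner products against the flipped kernel, and a trailing loop over shrinking overlaps. It can be compared directly against np.convolve:

import numpy as np


def full_convolve1d_ref(x, y):
    # NumPy-only reference for the "full" strategy used in the dispatch above.
    nx, ny = len(x), len(y)
    if nx < ny:  # convolution is symmetric in its two inputs
        x, y = y, x
        nx, ny = ny, nx
    y_flipped = y[::-1]
    ret = np.empty(nx + ny - 1, dtype=np.result_type(x, y))
    idx = 0
    for i in range(ny - 1):  # leading partial overlaps
        k = i + 1
        ret[idx] = x[:k] @ y_flipped[-k:]
        idx += 1
    for i in range(nx - ny + 1):  # full overlaps (this alone is the "valid" output)
        ret[idx] = x[i : i + ny] @ y_flipped
        idx += 1
    for i in range(ny - 1):  # trailing partial overlaps
        k = ny - i - 1
        ret[idx] = x[-k:] @ y_flipped[:k]
        idx += 1
    return ret


rng = np.random.default_rng(0)
x, y = rng.normal(size=183), rng.normal(size=6)
np.testing.assert_allclose(full_convolve1d_ref(x, y), np.convolve(x, y, mode="full"))

The "valid" branch in the dispatch is simply the middle loop on its own, which is why it reduces to a single pass of inner products.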
1 change: 1 addition & 0 deletions pytensor/tensor/rewriting/__init__.py
@@ -3,6 +3,7 @@
 import pytensor.tensor.rewriting.blas_c
 import pytensor.tensor.rewriting.blas_scipy
 import pytensor.tensor.rewriting.blockwise
+import pytensor.tensor.rewriting.conv
 import pytensor.tensor.rewriting.einsum
 import pytensor.tensor.rewriting.elemwise
 import pytensor.tensor.rewriting.extra_ops
78 changes: 78 additions & 0 deletions pytensor/tensor/rewriting/conv.py
@@ -0,0 +1,78 @@
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.signal.conv import Convolve1d
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor


@register_stabilize
@register_specialize
@node_rewriter([Subtensor])
def local_sliced_full_conv_to_valid_conv(fgraph, node):
    """Rewrite sliced full conv that are equivalent to valid.

    The gradient of a valid Conv1d always implements the worst case scenario - full convolution -
    because it would need to know which input is larger to do something smarter.
    If we find out (through rewrites or static shape) we provide the direct implementation
    which can be orders of magnitude faster.

    # if x.shape[-1] > y.shape[-1]
    # z = convolve1d(x, y, mode="full")
    # z[..., y.shape[-1] - 1: z.shape[-1] - y.shape[-1] - 1] -> convolve1d(x, y, mode="valid")
    """
    conv, *other_idx_vars = node.inputs

    if not (
        conv.owner is not None
        and isinstance(conv.owner.op, Blockwise)
        and isinstance(conv.owner.op.core_op, Convolve1d)
        and conv.owner.op.core_op.mode == "full"
    ):
        return None

    # Check we have an (a:b) constant slice at the last axis of the input
    idx_list = node.op.idx_list
    if not (len(idx_list) == conv.type.ndim and isinstance(idx_list[-1], slice)):
        return None

    last_slice = idx_list[-1]
    if not (
        last_slice.start is not None
        and last_slice.stop is not None
        and last_slice.step is None
    ):
        return None

    *other_idx_vars, start, stop = other_idx_vars
    if not (isinstance(start, Constant) and isinstance(stop, Constant)):
        return None

    x, y = conv.owner.inputs
    len_x = x.type.shape[-1]
    len_y = y.type.shape[-1]
    if len_x is None or len_y is None:
        return None

    start, stop = start.data, stop.data
    if len_x < len_y:
        # Convolution is symmetric with input order
        x, y = y, x
        len_x, len_y = len_y, len_x

    if (
        start == len_y - 1
        # equivalent to stop = conv.shape[-1] - len_y - 1
        and stop == start + (len_x - len_y) + 1
    ):
        new_conv = convolve1d(x, y, mode="valid")
        copy_stack_trace(conv, new_conv)

        if other_idx_vars:
            # If there were more than just empty slices besides the last one
            new_indices = indices_from_subtensor(idx_list[:-1], other_idx_vars)
            new_conv = new_conv[new_indices]
            copy_stack_trace(node.out, new_conv)

        return [new_conv]

Review thread on the comment line `# equivalent to stop = conv.shape[-1] - len_y - 1`:

Member: Why not use that form then? I don't understand this comment

Member (Author): Because I already extracted len_x, and I can use that directly
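The identity the rewrite relies on, namely that a "full" convolution sliced from y.shape[-1] - 1 up to x.shape[-1] is exactly the "valid" convolution when x is the longer input, is easy to check with plain NumPy (illustrative snippet; the sizes are arbitrary):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=128)  # longer input
y = rng.normal(size=64)   # shorter input

z_full = np.convolve(x, y, mode="full")  # length len(x) + len(y) - 1
start = len(y) - 1
stop = start + (len(x) - len(y)) + 1     # == len(x), the bounds the rewrite matches
np.testing.assert_allclose(z_full[start:stop], np.convolve(x, y, mode="valid"))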
9 changes: 5 additions & 4 deletions pytensor/tensor/signal/conv.py
@@ -15,7 +15,7 @@
 from pytensor.tensor import TensorLike
 
 
-class Conv1d(Op):
+class Convolve1d(Op):
     __props__ = ("mode",)
     gufunc_signature = "(n),(k)->(o)"
 
@@ -75,13 +75,14 @@ def L_op(self, inputs, outputs, output_grads):
         n = in1.shape[0]
         k = in2.shape[0]
         kmn = maximum(0, k - n)
-        nkm = maximum(0, n - k)
+        nmk = maximum(0, n - k)
         # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
         # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
+        # There is a rewrite that optimizes this case when n, k are static
         in1_bar = full_conv(grad, in2[::-1])
         in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
         in2_bar = full_conv(grad, in1[::-1])
-        in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]
+        in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]
 
         return [in1_bar, in2_bar]
 
@@ -129,4 +130,4 @@ def convolve1d(
         )
         mode = "valid"
 
-    return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
+    return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))
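The slicing trick described in the comments above can be illustrated with NumPy as well: the gradient-side convolution is always computed in "full" mode and then trimmed by max(0, n - k) (respectively max(0, k - n)), which reproduces what a "valid" convolution would have returned for the shorter input. A minimal sketch with made-up sizes:

import numpy as np

rng = np.random.default_rng(0)
n, k = 10, 4                       # in1 longer than in2
in1, in2 = rng.normal(size=n), rng.normal(size=k)
grad = np.ones(n - k + 1)          # upstream gradient of the "valid" output

kmn = max(0, k - n)                # 0 here
nmk = max(0, n - k)                # n - k here

in1_bar = np.convolve(grad, in2[::-1], mode="full")
in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]   # already length n, nothing trimmed
in2_bar = np.convolve(grad, in1[::-1], mode="full")
in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]   # trimmed down to length k

assert in1_bar.shape == in1.shape and in2_bar.shape == in2.shape
# The trimmed "full" result equals the direct "valid" convolution:
np.testing.assert_allclose(in2_bar, np.convolve(grad, in1[::-1], mode="valid"))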
43 changes: 40 additions & 3 deletions tests/link/numba/signal/test_conv.py
@@ -1,7 +1,10 @@
+from functools import partial
+
 import numpy as np
 import pytest
 
-from pytensor.tensor import dmatrix
+from pytensor import function
+from pytensor.tensor import dmatrix, tensor
 from pytensor.tensor.signal import convolve1d
 from tests.link.numba.test_basic import compare_numba_and_py
 
@@ -10,13 +13,47 @@
 
 
 @pytest.mark.parametrize("mode", ["full", "valid", "same"])
-def test_convolve1d(mode):
+@pytest.mark.parametrize("x_smaller", (False, True))
+def test_convolve1d(x_smaller, mode):
     x = dmatrix("x")
     y = dmatrix("y")
-    out = convolve1d(x[None], y[:, None], mode=mode)
+    if x_smaller:
+        out = convolve1d(x[None], y[:, None], mode=mode)
+    else:
+        out = convolve1d(y[:, None], x[None], mode=mode)
 
     rng = np.random.default_rng()
     test_x = rng.normal(size=(3, 5))
     test_y = rng.normal(size=(7, 11))
     # Blockwise dispatch for numba can't be run on object mode
     compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
+
+
+@pytest.mark.parametrize("mode", ("full", "valid"), ids=lambda x: f"mode={x}")
+@pytest.mark.parametrize("batch", (False, True), ids=lambda x: f"batch={x}")
+def test_convolve1d_benchmark(batch, mode, benchmark):
+    x = tensor(
+        shape=(
+            7,
+            183,
+        )
+        if batch
+        else (183,)
+    )
+    y = tensor(shape=(7, 6) if batch else (6,))
+    out = convolve1d(x, y, mode=mode)
+    fn = function([x, y], out, mode="NUMBA", trust_input=True)
+
+    rng = np.random.default_rng()
+    x_test = rng.normal(size=(x.type.shape)).astype(x.type.dtype)
+    y_test = rng.normal(size=(y.type.shape)).astype(y.type.dtype)
+
+    np_convolve1d = np.vectorize(
+        partial(np.convolve, mode=mode), signature="(x),(y)->(z)"
+    )
+
+    np.testing.assert_allclose(
+        fn(x_test, y_test),
+        np_convolve1d(x_test, y_test),
+    )
+    benchmark(fn, x_test, y_test)
33 changes: 30 additions & 3 deletions tests/tensor/signal/test_conv.py
@@ -5,10 +5,11 @@
 from scipy.signal import convolve as scipy_convolve
 
 from pytensor import config, function, grad
-from pytensor.graph import ancestors, rewrite_graph
+from pytensor.graph.basic import ancestors, io_toposort
+from pytensor.graph.rewriting import rewrite_graph
 from pytensor.tensor import matrix, vector
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.signal.conv import Conv1d, convolve1d
+from pytensor.tensor.signal.conv import Convolve1d, convolve1d
 from tests import unittest_tools as utt
 
 
@@ -81,4 +82,30 @@ def test_convolve1d_batch_graph(mode):
         if var.owner is not None and isinstance(var.owner.op, Blockwise)
     ]
     # Check any Blockwise are just Conv1d
-    assert all(isinstance(node.op.core_op, Conv1d) for node in blockwise_nodes)
+    assert all(isinstance(node.op.core_op, Convolve1d) for node in blockwise_nodes)
+
+
+@pytest.mark.parametrize("static_shape", [False, True])
+def test_convolve1d_valid_grad_rewrite(static_shape):
+    """Test that we don't do a useless full convolve1d when taking the gradient of a valid convolve wrt to the smallest input.
+
+    This can only be achieved when the two inputs have static shapes, so we know which one is larger
+    """
+    larger = vector("larger", shape=(128 if static_shape else None,))
+    smaller = vector("smaller", shape=(64 if static_shape else None,))
+    out = convolve1d(larger, smaller, mode="valid")
+    grad_out = rewrite_graph(
+        grad(out.sum(), wrt=smaller),
+        include=(
+            "ShapeOpt",
+            "canonicalize",
+            "stabilize",
+            "local_useless_unbatched_blockwise",
+        ),
+    )
+    [conv_op] = [
+        node.op
+        for node in io_toposort([larger, smaller], [grad_out])
+        if isinstance(node.op, Convolve1d)
+    ]
+    assert conv_op.mode == ("valid" if static_shape else "full")
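As a usage-level sketch of what the rewrite buys (assumed shapes; not part of the PR), with static shapes the gradient of a "valid" convolution with respect to the shorter input should compile to a valid-mode Convolve1d instead of a sliced full-mode one:

import pytensor
from pytensor.tensor import vector
from pytensor.tensor.signal import convolve1d

larger = vector("larger", shape=(128,))
smaller = vector("smaller", shape=(64,))
out = convolve1d(larger, smaller, mode="valid")
grad_smaller = pytensor.grad(out.sum(), wrt=smaller)

fn = pytensor.function([larger, smaller], grad_smaller)
pytensor.dprint(fn)  # expect a valid-mode Convolve1d, not a Subtensor slice of a full-mode one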