diff --git a/pytensor/link/jax/dispatch/signal/conv.py b/pytensor/link/jax/dispatch/signal/conv.py
index 1c124065e2..92414ac59a 100644
--- a/pytensor/link/jax/dispatch/signal/conv.py
+++ b/pytensor/link/jax/dispatch/signal/conv.py
@@ -1,11 +1,11 @@
 import jax
 
 from pytensor.link.jax.dispatch import jax_funcify
-from pytensor.tensor.signal.conv import Conv1d
+from pytensor.tensor.signal.conv import Convolve1d
 
 
-@jax_funcify.register(Conv1d)
-def jax_funcify_Conv1d(op, node, **kwargs):
+@jax_funcify.register(Convolve1d)
+def jax_funcify_Convolve1d(op, node, **kwargs):
     mode = op.mode
 
     def conv1d(data, kernel):
diff --git a/pytensor/link/numba/dispatch/signal/conv.py b/pytensor/link/numba/dispatch/signal/conv.py
index b1c63a440c..cf163228ad 100644
--- a/pytensor/link/numba/dispatch/signal/conv.py
+++ b/pytensor/link/numba/dispatch/signal/conv.py
@@ -1,16 +1,70 @@
 import numpy as np
+from numba.np.arraymath import _get_inner_prod
 
 from pytensor.link.numba.dispatch import numba_funcify
 from pytensor.link.numba.dispatch.basic import numba_njit
-from pytensor.tensor.signal.conv import Conv1d
+from pytensor.tensor.signal.conv import Convolve1d
 
 
-@numba_funcify.register(Conv1d)
-def numba_funcify_Conv1d(op, node, **kwargs):
+@numba_funcify.register(Convolve1d)
+def numba_funcify_Convolve1d(op, node, **kwargs):
+    # This specialized version is faster than the overloaded numba np.convolve
     mode = op.mode
+    a_dtype, b_dtype = node.inputs[0].type.dtype, node.inputs[1].type.dtype
+    out_dtype = node.outputs[0].type.dtype
+    innerprod = _get_inner_prod(a_dtype, b_dtype)
 
-    @numba_njit
-    def conv1d(data, kernel):
-        return np.convolve(data, kernel, mode=mode)
+    if mode == "valid":
 
-    return conv1d
+        def valid_convolve1d(x, y):
+            nx = len(x)
+            ny = len(y)
+            if nx < ny:
+                x, y = y, x
+                nx, ny = ny, nx
+            y_flipped = y[::-1]
+
+            length = nx - ny + 1
+            ret = np.empty(length, out_dtype)
+
+            for i in range(length):
+                ret[i] = innerprod(x[i : i + ny], y_flipped)
+
+            return ret
+
+        return numba_njit(valid_convolve1d)
+
+    elif mode == "full":
+
+        def full_convolve1d(x, y):
+            nx = len(x)
+            ny = len(y)
+            if nx < ny:
+                x, y = y, x
+                nx, ny = ny, nx
+            y_flipped = y[::-1]
+
+            length = nx + ny - 1
+            ret = np.empty(length, out_dtype)
+            idx = 0
+
+            for i in range(ny - 1):
+                k = i + 1
+                ret[idx] = innerprod(x[:k], y_flipped[-k:])
+                idx = idx + 1
+
+            for i in range(nx - ny + 1):
+                ret[idx] = innerprod(x[i : i + ny], y_flipped)
+                idx = idx + 1
+
+            for i in range(ny - 1):
+                k = ny - i - 1
+                ret[idx] = innerprod(x[-k:], y_flipped[:k])
+                idx = idx + 1
+
+            return ret
+
+        return numba_njit(full_convolve1d)
+
+    else:
+        raise ValueError(f"Unsupported mode: {mode}")
diff --git a/pytensor/tensor/rewriting/__init__.py b/pytensor/tensor/rewriting/__init__.py
index 4e75140ceb..80b844cfae 100644
--- a/pytensor/tensor/rewriting/__init__.py
+++ b/pytensor/tensor/rewriting/__init__.py
@@ -3,6 +3,7 @@
 import pytensor.tensor.rewriting.blas_c
 import pytensor.tensor.rewriting.blas_scipy
 import pytensor.tensor.rewriting.blockwise
+import pytensor.tensor.rewriting.conv
 import pytensor.tensor.rewriting.einsum
 import pytensor.tensor.rewriting.elemwise
 import pytensor.tensor.rewriting.extra_ops
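For reference, the full-mode Numba kernel above computes the output in three phases: a left ramp where the kernel only partially overlaps the start of the signal, a middle region of full overlap (identical to "valid" mode), and a right ramp at the end. The following is a minimal NumPy sketch of the same decomposition, checked against np.convolve; it is illustrative only and not part of the patch (plain dot products stand in for numba's _get_inner_prod):

import numpy as np


def full_convolve1d_reference(x, y):
    # Mirrors full_convolve1d above: np.convolve(x, y, "full") computed as
    # sliding inner products against the flipped kernel, in three phases.
    if len(x) < len(y):
        x, y = y, x  # convolution is symmetric in its inputs
    nx, ny = len(x), len(y)
    y_flipped = y[::-1]
    ret = np.empty(nx + ny - 1, dtype=np.result_type(x, y))
    idx = 0
    for i in range(ny - 1):  # left ramp: kernel enters the signal
        k = i + 1
        ret[idx] = x[:k] @ y_flipped[-k:]
        idx += 1
    for i in range(nx - ny + 1):  # middle: full overlap ("valid" region)
        ret[idx] = x[i : i + ny] @ y_flipped
        idx += 1
    for i in range(ny - 1):  # right ramp: kernel exits the signal
        k = ny - i - 1
        ret[idx] = x[-k:] @ y_flipped[:k]
        idx += 1
    return ret


rng = np.random.default_rng(0)
x, y = rng.normal(size=12), rng.normal(size=5)
np.testing.assert_allclose(
    full_convolve1d_reference(x, y), np.convolve(x, y, mode="full")
)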
diff --git a/pytensor/tensor/rewriting/conv.py b/pytensor/tensor/rewriting/conv.py
new file mode 100644
index 0000000000..37a3fdc00f
--- /dev/null
+++ b/pytensor/tensor/rewriting/conv.py
@@ -0,0 +1,78 @@
+from pytensor.graph.basic import Constant
+from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize
+from pytensor.tensor.signal import convolve1d
+from pytensor.tensor.signal.conv import Convolve1d
+from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
+
+
+@register_stabilize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_sliced_full_conv_to_valid_conv(fgraph, node):
+    """Rewrite sliced full convolutions that are equivalent to valid convolutions.
+
+    The gradient of a valid Convolve1d always implements the worst-case scenario (a full convolution),
+    because it would need to know which input is larger to do something smarter.
+    If we can find out (through rewrites or static shapes), we provide the direct implementation,
+    which can be orders of magnitude faster.
+
+    # if x.shape[-1] > y.shape[-1]
+    # z = convolve1d(x, y, mode="full")
+    # z[..., y.shape[-1] - 1 : z.shape[-1] - y.shape[-1] + 1] -> convolve1d(x, y, mode="valid")
+    """
+    conv, *other_idx_vars = node.inputs
+
+    if not (
+        conv.owner is not None
+        and isinstance(conv.owner.op, Blockwise)
+        and isinstance(conv.owner.op.core_op, Convolve1d)
+        and conv.owner.op.core_op.mode == "full"
+    ):
+        return None
+
+    # Check we have an (a:b) constant slice at the last axis of the input
+    idx_list = node.op.idx_list
+    if not (len(idx_list) == conv.type.ndim and isinstance(idx_list[-1], slice)):
+        return None
+
+    last_slice = idx_list[-1]
+    if not (
+        last_slice.start is not None
+        and last_slice.stop is not None
+        and last_slice.step is None
+    ):
+        return None
+
+    *other_idx_vars, start, stop = other_idx_vars
+    if not (isinstance(start, Constant) and isinstance(stop, Constant)):
+        return None
+
+    x, y = conv.owner.inputs
+    len_x = x.type.shape[-1]
+    len_y = y.type.shape[-1]
+    if len_x is None or len_y is None:
+        return None
+
+    start, stop = start.data, stop.data
+    if len_x < len_y:
+        # Convolution is symmetric in its input order
+        x, y = y, x
+        len_x, len_y = len_y, len_x
+
+    if (
+        start == len_y - 1
+        # equivalent to stop = conv.shape[-1] - len_y + 1
+        and stop == start + (len_x - len_y) + 1
+    ):
+        new_conv = convolve1d(x, y, mode="valid")
+        copy_stack_trace(conv, new_conv)
+
+        if other_idx_vars:
+            # If there were non-trivial indices besides the last slice
+            new_indices = indices_from_subtensor(idx_list[:-1], other_idx_vars)
+            new_conv = new_conv[new_indices]
+            copy_stack_trace(node.out, new_conv)
+
+        return [new_conv]
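The bounds the rewrite matches against come from a simple identity: when len(x) >= len(y), the fully-overlapping middle of a full convolution is exactly the valid convolution, starting at len(y) - 1 and stopping at len(x), which equals z.shape[-1] - len(y) + 1. A quick NumPy check of that identity (illustrative, not part of the patch):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=10)  # the larger input
y = rng.normal(size=4)   # the smaller input

z = np.convolve(x, y, mode="full")    # length len(x) + len(y) - 1 == 13
start = len(y) - 1                    # == 3, first index with full overlap
stop = start + (len(x) - len(y)) + 1  # == 10 == len(x) == len(z) - len(y) + 1
np.testing.assert_allclose(z[start:stop], np.convolve(x, y, mode="valid"))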
diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py
index ab2856b694..59a5a9ea9c 100644
--- a/pytensor/tensor/signal/conv.py
+++ b/pytensor/tensor/signal/conv.py
@@ -15,7 +15,7 @@
     from pytensor.tensor import TensorLike
 
 
-class Conv1d(Op):
+class Convolve1d(Op):
     __props__ = ("mode",)
     gufunc_signature = "(n),(k)->(o)"
 
@@ -75,13 +75,14 @@ def L_op(self, inputs, outputs, output_grads):
         n = in1.shape[0]
         k = in2.shape[0]
         kmn = maximum(0, k - n)
-        nkm = maximum(0, n - k)
+        nmk = maximum(0, n - k)
         # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
         # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
+        # There is a rewrite that optimizes this case when n, k are static
         in1_bar = full_conv(grad, in2[::-1])
         in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
         in2_bar = full_conv(grad, in1[::-1])
-        in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]
+        in2_bar = in2_bar[nmk : in2_bar.shape[0] - nmk]
 
         return [in1_bar, in2_bar]
 
@@ -129,4 +130,4 @@ def convolve1d(
     )
     mode = "valid"
 
-    return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
+    return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))
diff --git a/tests/link/numba/signal/test_conv.py b/tests/link/numba/signal/test_conv.py
index 1a72c2df0b..d1e90a6dae 100644
--- a/tests/link/numba/signal/test_conv.py
+++ b/tests/link/numba/signal/test_conv.py
@@ -1,7 +1,10 @@
+from functools import partial
+
 import numpy as np
 import pytest
 
-from pytensor.tensor import dmatrix
+from pytensor import function
+from pytensor.tensor import dmatrix, tensor
 from pytensor.tensor.signal import convolve1d
 from tests.link.numba.test_basic import compare_numba_and_py
 
@@ -10,13 +13,47 @@
 
 
 @pytest.mark.parametrize("mode", ["full", "valid", "same"])
-def test_convolve1d(mode):
+@pytest.mark.parametrize("x_smaller", (False, True))
+def test_convolve1d(x_smaller, mode):
     x = dmatrix("x")
     y = dmatrix("y")
-    out = convolve1d(x[None], y[:, None], mode=mode)
+    if x_smaller:
+        out = convolve1d(x[None], y[:, None], mode=mode)
+    else:
+        out = convolve1d(y[:, None], x[None], mode=mode)
     rng = np.random.default_rng()
    test_x = rng.normal(size=(3, 5))
     test_y = rng.normal(size=(7, 11))
 
     # Blockwise dispatch for numba can't be run on object mode
     compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
+
+
+@pytest.mark.parametrize("mode", ("full", "valid"), ids=lambda x: f"mode={x}")
+@pytest.mark.parametrize("batch", (False, True), ids=lambda x: f"batch={x}")
+def test_convolve1d_benchmark(batch, mode, benchmark):
+    x = tensor(
+        shape=(
+            7,
+            183,
+        )
+        if batch
+        else (183,)
+    )
+    y = tensor(shape=(7, 6) if batch else (6,))
+    out = convolve1d(x, y, mode=mode)
+    fn = function([x, y], out, mode="NUMBA", trust_input=True)
+
+    rng = np.random.default_rng()
+    x_test = rng.normal(size=(x.type.shape)).astype(x.type.dtype)
+    y_test = rng.normal(size=(y.type.shape)).astype(y.type.dtype)
+
+    np_convolve1d = np.vectorize(
+        partial(np.convolve, mode=mode), signature="(x),(y)->(z)"
+    )
+
+    np.testing.assert_allclose(
+        fn(x_test, y_test),
+        np_convolve1d(x_test, y_test),
+    )
+    benchmark(fn, x_test, y_test)
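The benchmark above builds its batched reference with np.vectorize and a gufunc-style signature, which maps np.convolve over the leading batch dimension. The same pattern works stand-alone; a small sketch with illustrative shapes:

from functools import partial

import numpy as np

# Batched reference convolution: "(x),(y)->(z)" applies np.convolve to each
# pair of rows, broadcasting over the leading (batch) dimensions.
np_convolve1d = np.vectorize(
    partial(np.convolve, mode="full"), signature="(x),(y)->(z)"
)

rng = np.random.default_rng(0)
x = rng.normal(size=(7, 183))
y = rng.normal(size=(7, 6))
out = np_convolve1d(x, y)
assert out.shape == (7, 183 + 6 - 1)  # full mode: n + k - 1 per batch row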
diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py
index d56d365193..d6b0d69d7c 100644
--- a/tests/tensor/signal/test_conv.py
+++ b/tests/tensor/signal/test_conv.py
@@ -5,10 +5,11 @@
 from scipy.signal import convolve as scipy_convolve
 
 from pytensor import config, function, grad
-from pytensor.graph import ancestors, rewrite_graph
+from pytensor.graph.basic import ancestors, io_toposort
+from pytensor.graph.rewriting import rewrite_graph
 from pytensor.tensor import matrix, vector
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.signal.conv import Conv1d, convolve1d
+from pytensor.tensor.signal.conv import Convolve1d, convolve1d
 from tests import unittest_tools as utt
 
 
@@ -81,4 +82,30 @@ def test_convolve1d_batch_graph(mode):
         if var.owner is not None and isinstance(var.owner.op, Blockwise)
     ]
     # Check any Blockwise are just Conv1d
-    assert all(isinstance(node.op.core_op, Conv1d) for node in blockwise_nodes)
+    assert all(isinstance(node.op.core_op, Convolve1d) for node in blockwise_nodes)
+
+
+@pytest.mark.parametrize("static_shape", [False, True])
+def test_convolve1d_valid_grad_rewrite(static_shape):
+    """Test that we don't do a useless full convolve1d when taking the gradient of a valid convolve wrt the smaller input.
+
+    This can only be achieved when the two inputs have static shapes, so we know which one is larger.
+    """
+    larger = vector("larger", shape=(128 if static_shape else None,))
+    smaller = vector("smaller", shape=(64 if static_shape else None,))
+    out = convolve1d(larger, smaller, mode="valid")
+    grad_out = rewrite_graph(
+        grad(out.sum(), wrt=smaller),
+        include=(
+            "ShapeOpt",
+            "canonicalize",
+            "stabilize",
+            "local_useless_unbatched_blockwise",
+        ),
+    )
+    [conv_op] = [
+        node.op
+        for node in io_toposort([larger, smaller], [grad_out])
+        if isinstance(node.op, Convolve1d)
+    ]
+    assert conv_op.mode == ("valid" if static_shape else "full")
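In user code, the rewrite exercised by this test means that the gradient of a valid convolution compiles down to another valid convolution whenever static shapes reveal which input is larger. A hedged sketch of that usage (variable names are illustrative, not from the test suite):

import numpy as np

from pytensor import function, grad
from pytensor.tensor import vector
from pytensor.tensor.signal import convolve1d

# Static shapes let the rewrite prove which input is larger.
larger = vector("larger", shape=(128,))
smaller = vector("smaller", shape=(64,))
out = convolve1d(larger, smaller, mode="valid")

# Without the rewrite this gradient would be a full convolution plus a slice;
# with static shapes it is rewritten to a direct valid convolution.
g = grad(out.sum(), wrt=smaller)
fn = function([larger, smaller], g)

rng = np.random.default_rng(0)
print(fn(rng.normal(size=128), rng.normal(size=64)).shape)  # (64,)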