Implement batched convolve1d #1318
@@ -0,0 +1 @@
import pytensor.link.jax.dispatch.signal.conv
@@ -0,0 +1,14 @@
import jax

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.signal.conv import Conv1d


@jax_funcify.register(Conv1d)
def jax_funcify_Conv1d(op, node, **kwargs):
    mode = op.mode

    def conv1d(data, kernel):
        return jax.numpy.convolve(data, kernel, mode=mode)

    return conv1d
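
As a quick sanity check (a standalone sketch, not part of the diff), `jax.numpy.convolve` follows NumPy's semantics for the two modes this Op supports, so the dispatched function should agree with the Python-mode `perform`:

```python
import jax.numpy as jnp
import numpy as np

# Small integer-valued inputs so float32 (JAX's default) and float64 agree exactly
data = np.arange(5.0)
kernel = np.array([1.0, -1.0])
for mode in ("full", "valid"):
    np.testing.assert_allclose(
        jnp.convolve(data, kernel, mode=mode),
        np.convolve(data, kernel, mode=mode),
    )
```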
@@ -0,0 +1 @@
import pytensor.link.numba.dispatch.signal.conv
@@ -0,0 +1,16 @@
import numpy as np

from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.signal.conv import Conv1d


@numba_funcify.register(Conv1d)
def numba_funcify_Conv1d(op, node, **kwargs):
    mode = op.mode

    @numba_njit
    def conv1d(data, kernel):
        return np.convolve(data, kernel, mode=mode)

    return conv1d
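
A minimal sketch (using plain `numba.njit` instead of PyTensor's `numba_njit` wrapper) of the closure pattern above: `mode` is read from the Op once, at dispatch time, so the jitted function captures it as a compile-time constant. It assumes, as the dispatch above does, that numba's `np.convolve` overload accepts the `mode` argument:

```python
import numpy as np
from numba import njit


def make_conv1d(mode):
    @njit
    def conv1d(data, kernel):
        # `mode` is a closure constant here, mirroring the dispatch above
        return np.convolve(data, kernel, mode=mode)

    return conv1d


conv_full = make_conv1d("full")
print(conv_full(np.arange(5.0), np.array([1.0, -1.0])))
```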
@@ -0,0 +1,4 @@
from pytensor.tensor.signal.conv import convolve1d


__all__ = ("convolve1d",)
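
A hedged usage sketch of the public entry point (input sizes made up): because the Op is wrapped in `Blockwise`, leading dimensions broadcast gufunc-style, e.g. a single kernel against a batch of signals:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.signal import convolve1d

x = pt.matrix("x")  # batch of signals, shape (batch, n)
y = pt.vector("y")  # a single kernel, shape (k,), broadcast over the batch
out = convolve1d(x, y, mode="full")  # shape (batch, n + k - 1)

rng = np.random.default_rng(0)
x_val = rng.normal(size=(4, 10)).astype(x.dtype)
y_val = np.array([1.0, 0.5, 0.25], dtype=y.dtype)
res = out.eval({x: x_val, y: y_val})
assert res.shape == (4, 12)
```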
@@ -0,0 +1,132 @@
from typing import TYPE_CHECKING, Literal, cast

from numpy import convolve as numpy_convolve

from pytensor.graph import Apply, Op
from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum
from pytensor.tensor.type import vector
from pytensor.tensor.variable import TensorVariable


if TYPE_CHECKING:
    from pytensor.tensor import TensorLike


class Conv1d(Op):
    __props__ = ("mode",)
    gufunc_signature = "(n),(k)->(o)"

    def __init__(self, mode: Literal["full", "valid"] = "full"):
        if mode not in ("full", "valid"):
            raise ValueError(f"Invalid mode: {mode}")
        self.mode = mode

    def make_node(self, in1, in2):
        in1 = as_tensor_variable(in1)
        in2 = as_tensor_variable(in2)

        assert in1.ndim == 1
        assert in2.ndim == 1

        dtype = upcast(in1.dtype, in2.dtype)

        n = in1.type.shape[0]
        k = in2.type.shape[0]

        if n is None or k is None:
            out_shape = (None,)
        elif self.mode == "full":
            out_shape = (n + k - 1,)
        else:  # mode == "valid"
            out_shape = (max(n, k) - min(n, k) + 1,)

        out = vector(dtype=dtype, shape=out_shape)
        return Apply(self, [in1, in2], [out])

    def perform(self, node, inputs, outputs):
        # We use numpy_convolve as that's what scipy would use if method="direct"
        # was passed and mode != "same" (a mode this Op doesn't cover anyway).
        outputs[0][0] = numpy_convolve(*inputs, mode=self.mode)
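
The equivalence that comment relies on can be checked directly — a standalone NumPy/SciPy sketch, not part of the diff:

```python
import numpy as np
from scipy.signal import convolve as scipy_convolve

rng = np.random.default_rng(2)
a, b = rng.normal(size=10), rng.normal(size=4)
for mode in ("full", "valid"):
    np.testing.assert_allclose(
        np.convolve(a, b, mode=mode),
        scipy_convolve(a, b, mode=mode, method="direct"),
        atol=1e-12,  # allow for different summation orders
    )
```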

    def infer_shape(self, fgraph, node, shapes):
        in1_shape, in2_shape = shapes
        n = in1_shape[0]
        k = in2_shape[0]
        if self.mode == "full":
            shape = n + k - 1
        else:  # mode == "valid"
            shape = maximum(n, k) - minimum(n, k) + 1
        return [[shape]]
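
A quick NumPy sanity check of the two length formulas used above (standalone sketch):

```python
import numpy as np

for n, k in [(5, 3), (3, 5), (4, 4)]:
    full = np.convolve(np.ones(n), np.ones(k), mode="full")
    valid = np.convolve(np.ones(n), np.ones(k), mode="valid")
    assert full.shape == (n + k - 1,)
    assert valid.shape == (max(n, k) - min(n, k) + 1,)
```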

    def L_op(self, inputs, outputs, output_grads):
        in1, in2 = inputs
        [grad] = output_grads

        if self.mode == "full":
            valid_conv = type(self)(mode="valid")
            in1_bar = valid_conv(grad, in2[::-1])
            in2_bar = valid_conv(grad, in1[::-1])

        else:  # mode == "valid"
            full_conv = type(self)(mode="full")
            n = in1.shape[0]
            k = in2.shape[0]
            kmn = maximum(0, k - n)
            nkm = maximum(0, n - k)
            # We need mode="full" if k >= n else "valid" for `in1_bar` (and the
            # opposite for `in2_bar`), but `mode` is not symbolic. Instead, we
            # always use mode="full" and slice the result so it behaves like
            # "valid" for the input that is shorter.
            in1_bar = full_conv(grad, in2[::-1])
            in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
            in2_bar = full_conv(grad, in1[::-1])
            in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]

        return [in1_bar, in2_bar]
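
For the `mode="full"` branch, the identity used above — the gradient with respect to `in1` is a "valid" convolution of the output cotangent with the flipped kernel — can be verified against finite differences; a standalone NumPy sketch with made-up sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 6, 3
a, b = rng.normal(size=n), rng.normal(size=k)
g = rng.normal(size=n + k - 1)  # cotangent of the "full" output

# Analytic gradient of C = sum(g * np.convolve(a, b, "full")) w.r.t. a,
# as computed in the mode="full" branch above
a_bar = np.convolve(g, b[::-1], mode="valid")

# Central finite differences
eps = 1e-6
fd = np.empty(n)
for i in range(n):
    ap, am = a.copy(), a.copy()
    ap[i] += eps
    am[i] -= eps
    fd[i] = (g @ np.convolve(ap, b) - g @ np.convolve(am, b)) / (2 * eps)

np.testing.assert_allclose(a_bar, fd, rtol=1e-6, atol=1e-8)
```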


def convolve1d(
    in1: "TensorLike",
    in2: "TensorLike",
    mode: Literal["full", "valid", "same"] = "full",
) -> TensorVariable:
    """Convolve two one-dimensional arrays.

    Convolve `in1` and `in2`, with the output size determined by the `mode` argument.

    Parameters
    ----------
    in1 : (..., N,) tensor_like
        First input.
    in2 : (..., M,) tensor_like
        Second input.
    mode : {'full', 'valid', 'same'}, optional
        A string indicating the size of the output:
        - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+M-1,).
        - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, M) - min(N, M) + 1,).
        - 'same': The output is the same size as `in1`, centered with respect to the 'full' output.

    Returns
    -------
    out : tensor_variable
        The discrete linear convolution of `in1` with `in2`.
    """
    in1 = as_tensor_variable(in1)
    in2 = as_tensor_variable(in2)
if mode == "same": | ||
# We implement "same" as "valid" with padded `in1`. | ||
in1_batch_shape = tuple(in1.shape)[:-1] | ||
zeros_left = in2.shape[0] // 2 | ||
zeros_right = (in2.shape[0] - 1) // 2 | ||
in1 = join( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't want pad until we figure out inline for it. I want PyTensor to optimize across the boundary, specially when gradients get involved |
||
-1, | ||
zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype), | ||
in1, | ||
zeros((*in1_batch_shape, zeros_right), dtype=in2.dtype), | ||
) | ||
mode = "valid" | ||
|
||
return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2)) |
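
The padding trick for "same" can be checked against SciPy's convention directly (standalone sketch; sizes made up, including the k > n case):

```python
import numpy as np
from scipy.signal import convolve as scipy_convolve

rng = np.random.default_rng(1)
for n, k in [(8, 3), (8, 4), (3, 8)]:
    in1, in2 = rng.normal(size=n), rng.normal(size=k)
    padded = np.concatenate([np.zeros(k // 2), in1, np.zeros((k - 1) // 2)])
    np.testing.assert_allclose(
        np.convolve(padded, in2, mode="valid"),
        scipy_convolve(in1, in2, mode="same"),
        atol=1e-12,
    )
```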
@@ -0,0 +1,18 @@
import numpy as np
import pytest

from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.jax.test_basic import compare_jax_and_py


@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    x = dmatrix("x")
    y = dmatrix("y")
    out = convolve1d(x[None], y[:, None], mode=mode)

    rng = np.random.default_rng()
    test_x = rng.normal(size=(3, 5))
    test_y = rng.normal(size=(7, 11))
    compare_jax_and_py([x, y], out, [test_x, test_y])

Review exchange on the `test_x` line: a reviewer proposed a suggested change; the author declined it with "Nope, no stupid float32 in jax/numba tests, and also I explicitly used dmatrix."
@@ -0,0 +1,22 @@
import numpy as np
import pytest

from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.numba.test_basic import compare_numba_and_py


pytestmark = pytest.mark.filterwarnings("error")


@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    x = dmatrix("x")
    y = dmatrix("y")
    out = convolve1d(x[None], y[:, None], mode=mode)

    rng = np.random.default_rng()
    test_x = rng.normal(size=(3, 5))
    test_y = rng.normal(size=(7, 11))
    # Blockwise dispatch for numba can't be run in object mode
    compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
@@ -0,0 +1,49 @@
from functools import partial

import numpy as np
import pytest
from scipy.signal import convolve as scipy_convolve

from pytensor import config, function
from pytensor.tensor import matrix, vector
from pytensor.tensor.signal.conv import convolve1d
from tests import unittest_tools as utt


@pytest.mark.parametrize("kernel_shape", [3, 5, 8], ids=lambda x: f"kernel_shape={x}")
@pytest.mark.parametrize("data_shape", [3, 5, 8], ids=lambda x: f"data_shape={x}")
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode, data_shape, kernel_shape):
    data = vector("data")
    kernel = vector("kernel")
    op = partial(convolve1d, mode=mode)

    rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
    data_val = rng.normal(size=data_shape).astype(data.dtype)
    kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)

    fn = function([data, kernel], op(data, kernel))
    np.testing.assert_allclose(
        fn(data_val, kernel_val),
        scipy_convolve(data_val, kernel_val, mode=mode),
        rtol=1e-6 if config.floatX == "float32" else 1e-15,
    )
    utt.verify_grad(op=lambda x: op(x, kernel_val), pt=[data_val])


def test_convolve1d_batch():
    x = matrix("data")
    y = matrix("kernel")
    out = convolve1d(x, y)

    rng = np.random.default_rng(38)
    x_test = rng.normal(size=(2, 8)).astype(x.dtype)
    y_test = x_test[::-1]

    res = out.eval({x: x_test, y: y_test})
    # The second pair (x_test[1], y_test[1]) is the first pair with its arguments
    # swapped, and convolution is commutative, so res[0] and res[1] should be identical.
    rtol = 1e-6 if config.floatX == "float32" else 1e-15
    res_np = np.convolve(x_test[0], y_test[0])
    np.testing.assert_allclose(res[0], res_np, rtol=rtol)
    np.testing.assert_allclose(res[1], res_np, rtol=rtol)