Faster convolve1d in numba backend #1378
Conversation
Force-pushed from 66fa69a to 02823cc
Codecov Report
❌ Your patch status has failed because the patch coverage (50.53%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files:
@@ Coverage Diff @@
## main #1378 +/- ##
==========================================
- Coverage 82.07% 82.01% -0.06%
==========================================
Files 206 207 +1
Lines 49174 49250 +76
Branches 8720 8734 +14
==========================================
+ Hits 40359 40394 +35
- Misses 6656 6692 +36
- Partials 2159 2164 +5
Pull Request Overview
This PR reimplements the core logic of convolve1d in the numba backend, yielding a 6× speedup in benchmarks with small inputs, and also optimizes the gradient computation for valid convolutions when the smaller input's shape is known statically (a usage sketch of that gradient case follows the list below). In addition, the PR renames Conv1d to Convolve1d for consistency with the user-facing function name and updates various test and dispatch files to reflect these changes.
- Renames Conv1d to Convolve1d across modules.
- Adds new tests for gradient optimization and benchmarks for numba convolve1d.
- Updates rewriting and dispatch code to support the new implementation.
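As referenced above, here is a hypothetical usage sketch of the gradient case the new rewrite targets. The import path and variable names are assumptions inferred from the files touched by this PR, not confirmed code:

```python
# Hypothetical sketch: a "valid" convolution whose smaller input has a
# static shape, so the rewrite can avoid a full convolve in its gradient.
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.signal import convolve1d  # assumed import path

x = pt.vector("x", shape=(100,))  # larger input
y = pt.vector("y", shape=(5,))    # smaller input, shape known statically
out = convolve1d(x, y, mode="valid")

# Gradient wrt the smaller input; y's static shape is available at rewrite
# time even though it may not be known at the time of grad.
grad_y = pytensor.grad(out.sum(), y)
fn = pytensor.function([x, y], grad_y)
```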
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
File | Description
---|---
tests/tensor/signal/test_conv.py | Updated to import Convolve1d and added a test for the gradient rewrite optimization.
tests/link/numba/signal/test_conv.py | Adjusted tests to optionally swap inputs, and added a benchmark test.
pytensor/tensor/signal/conv.py | Renamed Conv1d to Convolve1d and updated internal variable naming for clarity.
pytensor/tensor/rewriting/conv.py | Added a rewrite rule to optimize valid convolution gradients for static shapes.
pytensor/tensor/rewriting/__init__.py | Imported the new conv rewriting module.
pytensor/link/numba/dispatch/signal/conv.py | Updated to register Convolve1d and implemented specialized numba functions.
pytensor/link/jax/dispatch/signal/conv.py | Updated to register Convolve1d.
lgtm, left ignorable suggestions
if (
    start == len_y - 1
    # equivalent to stop = conv.shape[-1] - len_y - 1
Why not use that form then? I don't understand this comment.
Because I already extracted len_x, and I can use that directly.
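For context, a quick numpy check of the slice relationship under discussion (my own illustration, not code from the PR): for 1D inputs with len_x >= len_y, the "valid" convolution is the "full" convolution sliced at [len_y - 1 : len_x], and since the full output has length len_x + len_y - 1, the stop bound len_x can equivalently be written in terms of conv.shape[-1].

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=10)  # len_x = 10
y = rng.normal(size=4)   # len_y = 4
len_x, len_y = len(x), len(y)

full = np.convolve(x, y, mode="full")    # length len_x + len_y - 1 = 13
valid = np.convolve(x, y, mode="valid")  # length len_x - len_y + 1 = 7

# The valid result is the full result sliced at [len_y - 1 : len_x];
# note len_x == full.shape[-1] - len_y + 1, so the stop bound can be
# expressed either way.
np.testing.assert_allclose(full[len_y - 1 : len_x], valid)
```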
Force-pushed from 02823cc to f1102ba
Force-pushed from f1102ba to e2c8464
These show up in the gradient of Convolve1d.
Force-pushed from e2c8464 to f0ef8fb
Reimplementing the core logic in the numba overload of convolve/correlate gives a 6× speedup in the benchmarked test with relatively small inputs. I guess the overloads don't optimize/propagate constant checks as well? It's a bit surprising, but the results are crystal clear.

Also added a rewrite to optimize the gradient of valid convolutions with respect to the smaller input, in which case we don't need a full convolve. This is done at the rewrite level because the static shape may not be known at the time of grad.
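A minimal sketch of the idea, assuming an explicit loop kernel; the function name and loop structure are illustrative, not the PR's actual implementation:

```python
import numpy as np
from numba import njit

@njit(cache=True)
def valid_convolve1d(x, y):  # hypothetical helper name
    len_x, len_y = x.shape[0], y.shape[0]
    out = np.empty(len_x - len_y + 1, dtype=x.dtype)
    for i in range(out.shape[0]):
        acc = 0.0
        for j in range(len_y):
            # Convolution flips the kernel; a correlation would use y[j].
            acc += x[i + j] * y[len_y - 1 - j]
        out[i] = acc
    return out
```

For float inputs this matches np.convolve(x, y, mode="valid"); writing the loop directly sidesteps the generic overload machinery, which may be where the constant checks the author mentions fail to get optimized away.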
Finally, renamed Conv1d to Convolve1d, which is more in line with the user-facing function.