local_subtensor_merge can complicate graphs #112
Comments
Yes, it may make sense to restrict it to cases with static indexes. Is this a realistic example, though?
I agree with @aseyboldt and I see this example as realistic enough. You can use such nested indexing to represent a grouped hierarchy, but I'm not sure whether that hits this particular issue.
A problem is that I think some scan rewrites depend on this. But I think we should probably restrict this one to static indexes or so and worry about fixing scan later.
This gets utterly insane on the gradient of scans:

from pytensor import scan, function
from pytensor.compile.mode import get_mode
import pytensor.tensor as pt

x0 = pt.scalar("x0")
n = pt.scalar("n", dtype=int)

# Repeatedly square the previous value for n steps.
xs, _ = scan(
    fn=lambda xtm1: xtm1 ** 2,
    outputs_info=[x0],
    n_steps=n,
)

# Gradient of the last scan output with respect to the initial value.
grad_xs_wrt_x0 = pt.grad(xs[-1], x0)

# Swap .including(...) for .excluding(...) to compile without the rewrite.
mode = get_mode("fast_run").including("local_subtensor_merge")
fn = function([n, x0], grad_xs_wrt_x0, mode=mode)
fn.dprint(print_shape=True, print_op_info=True)

Try it with and without the rewrite.
For scan I think we're mostly interested in 2 cases:
It's so bad I'm marking it as a bug.
Description
The local_subtensor_merge rewrite often makes graphs worse instead of better:
https://github.com/pymc-devs/pytensor/blob/main/pytensor/tensor/rewriting/subtensor.py#L475
I think this rewrite might be fine in some special cases with known shapes/indices, but in general I don't see why we would do this rewrite.
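
To illustrate why known shapes and indices make the difference, here is a minimal plain-Python sketch of merging two consecutive slices into one. It is not PyTensor's implementation; `merge_slices` is a hypothetical helper and it assumes the sequence length is a concrete integer. With static values, all the clamping and special cases are resolved up front and the result is a single simple slice. With symbolic lengths or bounds, each of those branches would presumably have to be expressed as extra nodes in the graph, which is how the merged graph can end up larger than the nested one.

```python
def merge_slices(first: slice, second: slice, length: int) -> slice:
    """Return one slice equivalent to applying `first` and then `second`
    to a sequence whose length is statically known (hypothetical helper)."""
    # Slicing a range makes Python perform all the bound clamping for us.
    merged = range(length)[first][second]
    if len(merged) == 0:
        # Nothing is selected, so return a canonical empty slice.
        return slice(0, 0, 1)
    # A negative stop only appears on reversed ranges that run down to index 0.
    # In a slice a negative stop would wrap around, so use None instead.
    stop = merged.stop if merged.stop >= 0 else None
    return slice(merged.start, stop, merged.step)


data = list(range(10))
first, second = slice(2, 9), slice(1, None, 2)
merged = merge_slices(first, second, len(data))

# The nested indexing and the merged slice select the same elements.
assert data[first][second] == data[merged]
```

Every branch above depends on concrete integers, which is the kind of information that is only available in the static case.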