Description
In #1429 we use the following expression to lower some xtensor indexing operations to tensor operations:
This creates an unnecessary intermediate tensor and indexing operation. We should add rewrites to:
Convert arange(...)[slice(start, stop, step)] directly to an equivalent arange(...)
Convert arange(...)[scalar_index] to just the scalar value
In both cases, we need to handle potentially negative indices (or None in slices). All variables in arange and in the indexing may also be symbolic. However, we don't want to risk creating needlessly complex graphs like those in #112.
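The index arithmetic behind both rewrites can be sketched with concrete Python integers. This is a minimal sketch, assuming integer start, stop and step; the helper names slice_arange_params and index_arange_value are hypothetical and not part of PyTensor, and a real rewrite would presumably build the same expressions symbolically from the ARange inputs. It relies on two facts: slice.indices resolves None and negative slice entries against the arange length, and element k of arange(start, stop, step) equals start + k * step.

```python
def slice_arange_params(start: int, stop: int, step: int, sl: slice) -> tuple[int, int, int]:
    """Hypothetical helper: return (start, stop, step) such that
    range(*result) has the same elements as range(start, stop, step)[sl]."""
    if step == 0:
        raise ValueError("arange step must be non-zero")
    # Number of elements in the original arange (0 if it is empty).
    n = len(range(start, stop, step))
    # slice.indices resolves None and negative entries against the length n.
    i0, i1, istep = sl.indices(n)
    # Element k of the arange is start + k * step, so the sliced index
    # sequence range(i0, i1, istep) maps onto an arange of values.
    return start + i0 * step, start + i1 * step, step * istep


def index_arange_value(start: int, stop: int, step: int, i: int) -> int:
    """Hypothetical helper: return the value of range(start, stop, step)[i]."""
    if step == 0:
        raise ValueError("arange step must be non-zero")
    n = len(range(start, stop, step))
    if i < 0:
        i += n  # negative indices count from the end
    if not 0 <= i < n:
        raise IndexError("index out of bounds for arange")
    return start + i * step


# Usage: compare against slicing/indexing [2, 5, 8, 11, 14, 17] directly.
print(slice_arange_params(2, 20, 3, slice(1, -1, 2)))  # (5, 17, 6)
print(index_arange_value(2, 20, 3, -1))  # 17
```

Python's built-in range already supports this kind of slicing (range(0, 10)[2:5] returns range(2, 5)), which makes it a convenient reference when testing such a rewrite. For floating-point aranges, recomputing the length from the new start and stop might be off by one due to rounding, so a symbolic version may need to treat that case separately.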
These optimizations would benefit both tensor and xtensor code, as the tensor rewrites would automatically apply after xtensor lowering.
Files to change
pytensor/tensor/rewriting/subtensor_lift.py: Add new optimization patterns
Implementation details
Two new pattern rewrites should be added:
A rewrite that converts Subtensor(ARange, slice(...)) to a direct ARange with adjusted parameters
A rewrite that converts Subtensor(ARange, scalar) to a constant or scalar value
Current behavior
If we profile or print the current graph for a function like:
We'll see that the compiled graph includes: