
Optimize slices of arange #1431


Open
ricardoV94 opened this issue May 30, 2025 · 0 comments

ricardoV94 commented May 30, 2025

Description

In #1429 we use the following expression to lower some xtensor indexing operations to tensor operations:

adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)]

This creates an unnecessary intermediate tensor and indexing operation. We should add rewrites to:

  1. Convert arange(...)[slice(start, stop, step)] directly to an equivalent arange(...)
  2. Convert arange(...)[scalar_index] to just the scalar value

In both cases, we need to take care of potentially negative indices (or None entries in slices). All variables in the arange and in the indexing may also be symbolic. However, we don't want to risk creating crazy graphs like #112.

These optimizations would benefit both tensor and xtensor code, as the tensor rewrites would automatically apply after xtensor lowering.
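For constant inputs the intended equivalence is easy to spell out, since Python's slice.indices already normalizes negative and None entries for a known length. A minimal NumPy sketch (illustration only, not the PyTensor rewrite itself):

import numpy as np

n = 10
sl = slice(-7, None, 2)  # negative start, open stop

# slice.indices normalizes negative/None entries for a given length
start, stop, step = sl.indices(n)  # -> (3, 10, 2)
assert np.array_equal(np.arange(n)[sl], np.arange(start, stop, step))

# Scalar case: arange(n)[k] is just k, after normalizing a negative index
k = -2
assert np.arange(n)[k] == (k if k >= 0 else k + n)

With symbolic values, the same normalization has to be built out of symbolic ops (e.g. switch, minimum, maximum), which is where the risk of graph blow-up comes in.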

Files to change

  • pytensor/tensor/rewriting/subtensor_lift.py: Add new optimization patterns

Implementation details

Two new pattern rewrites should be added:

  1. A rewrite that converts Subtensor(ARange, slice(...)) to a direct ARange with adjusted parameters
  2. A rewrite that converts Subtensor(ARange, scalar) to a constant or scalar value
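A minimal skeleton of how the slice case could be hooked into subtensor_lift.py (the function name and the TODO body are placeholders, not the actual implementation):

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import ARange
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.tensor.subtensor import Subtensor


@register_canonicalize
@node_rewriter([Subtensor])
def local_subtensor_of_arange(fgraph, node):
    # Only fire when the Subtensor is applied directly to the output of an ARange
    ar = node.inputs[0]
    if not (ar.owner is not None and isinstance(ar.owner.op, ARange)):
        return None
    arange_start, arange_stop, arange_step = ar.owner.inputs
    # TODO: recover the slice (or scalar index) from node.op.idx_list and
    # node.inputs[1:], normalize negative/None entries, and return either an
    # equivalent ARange (case 1) or the scalar index expression (case 2).
    return None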

Current behavior

If we profile or print the current graph for a function like:

import pytensor
import pytensor.tensor as pt

shape_val = pt.lscalar("shape_val")
intermediate = pt.arange(shape_val)[:1]
fn = pytensor.function([shape_val], intermediate)
fn.dprint()
# Subtensor{:stop} [id A] 1
#  ├─ ARange{dtype='int64'} [id B] 0
#  │  ├─ 0 [id C]
#  │  ├─ shape_val [id D]
#  │  └─ 1 [id E]
#  └─ 1 [id F]

We'll see that the compiled graph includes:

  1. An ARange operation that builds the full sequence from 0 to shape_val
  2. A Subtensor operation that applies the slice, keeping only one value in this case

The full ARange output is an unnecessary intermediate tensor.
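For this particular example, the rewritten graph should be equivalent to something like the following (a sketch of the target, assuming shape_val is a non-negative length, as shapes are), with no ARange-plus-Subtensor pair left:

import pytensor
import pytensor.tensor as pt

shape_val = pt.lscalar("shape_val")
# arange(shape_val)[:1] is the same as arange(min(1, shape_val)) when shape_val >= 0
optimized = pt.arange(pt.minimum(1, shape_val))
fn = pytensor.function([shape_val], optimized)
fn.dprint()

Whether the rewrite emits exactly this form is up to the implementation; the point is that the intermediate ARange and the Subtensor disappear.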