Skip to content

Implement indexing operations for XTensorVariables #1429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: labeled_tensors
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented May 28, 2025

Split the work on indexing from #1411

This implements all forms of non-boolean indexing in xarray. I think the same framework works for boolean indices which are restricted to be vectors in xarray, but I haven't implemented those yet.

It deviates from xarray in that it doesn't try to prevent runtime broadcasting between indexers when lowering to tensor index operations. There's a test marked as xfail. I'm weary of complicating the graph with a bunch of specify_shape but it's something we can consider doing if users complain.

It does prevents compile-time known broadcasting (so if everything has static shapes it will error).

@ricardoV94 ricardoV94 marked this pull request as ready for review May 29, 2025 11:51
@ricardoV94 ricardoV94 requested review from Copilot and OriolAbril May 29, 2025 11:52
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR implements full indexing support for XTensorVariables by adding indexing operations (__getitem__, isel, head/tail/thin), lowering logic in the rewrite pass, and comprehensive tests.

  • Add a new suite of indexing tests in tests/xtensor/test_indexing.py and cover scalar case in test_math.py
  • Implement __getitem__, isel, and related slicing utilities in pytensor/xtensor/type.py
  • Introduce an Index op, lowering logic (rewriting/indexing.py), and integrate with vectorization fallbacks

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
tests/xtensor/test_math.py Add test_scalar_case for zero-dimensional tensors
tests/xtensor/test_indexing.py New tests covering basic, advanced, vector, matrix, and scalar indexing
pytensor/xtensor/vectorization.py Handle empty dims_and_shape in vectorization
pytensor/xtensor/type.py Implement __getitem__, isel, and dimension checks
pytensor/xtensor/indexing.py New Index op and helpers for converting indices
pytensor/xtensor/rewriting/indexing.py Lower XTensor indexing to TensorVariable indexing
pytensor/xtensor/rewriting/init.py Register the new indexing rewriter
pytensor/xtensor/init.py Remove unused XTensorType import
Comments suppressed due to low confidence (2)

tests/xtensor/test_indexing.py:32

  • [nitpick] The variable name shufled_dims is misspelled and ambiguous; consider renaming it to shuffled_dims.
        shufled_dims = tuple(np.random.permutation(dims))

pytensor/xtensor/indexing.py:3

  • Typo in comment: Uselful should be Useful.
# Uselful links to make sense of all the numpy/xarray complexity

@ricardoV94 ricardoV94 changed the title Implement index operations for XTensorVariables Implement non-boolean index operations for XTensorVariables May 29, 2025
@ricardoV94 ricardoV94 mentioned this pull request May 29, 2025
31 tasks
Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I read through the code and couldn't think of cases that would break, I also played for a while and everything seemed to work. I left some minor examples.

The main comment is I think we need more tests. The existing ones cover most branches already though. Happy to add some more here or later on.

@OriolAbril
Copy link
Member

This implements all forms of non-boolean indexing in xarray. I think the same framework works for boolean indices which are restricted to be vectors in xarray, but I haven't implemented those yet.

I agree, converting the provided boolean indexes to their integer equivalent as first step should be the only thing needed to support them.

@ricardoV94
Copy link
Member Author

Extra tests welcome @OriolAbril just push

@ricardoV94
Copy link
Member Author

I agree, converting the provided boolean indexes to their integer equivalent as first step should be the only thing needed to support them

Do we even need to convert?

@OriolAbril
Copy link
Member

OriolAbril commented May 29, 2025

Do we even need to convert?

I guess it will depend on what exactly the lowering does. For example:

# obs is DataArray (chain: 4, draw: 500, obs_id: 77)
# chain_idxs, draw_idxs are length 200 integer DataArrays with dims ("divergence_id",)
# mask is a length 77 boolean array with 16 Trues in it.

# This works
obs.isel(chain=chain_idxs, draw=draw_idxs, obs_id=mask)
# obs[chain_idxs, draw_idxs, mask]

# but both attempts below fail
obs.values[chain_idxs.values, draw_idxs.values, mask]
# IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (200,) (200,) (16,) 
obs.values[chain_idxs.values[:, None], draw_idxs.values[:, None], mask[None, :]]
# IndexError: too many indices

IIUC, the current lowering will generate indexes like I do in the 2nd failed attempt and forward to tensor indexing.

On the other hand the snippet below does work:

mask_idx, = np.nonzero(mask)
obs.values[chain_idxs.values[:, None], draw_idxs.values[:, None], mask_idx[None, :]]

@ricardoV94
Copy link
Member Author

@OriolAbril thanks for checking, it seems we need to convert to integers then for the general case.

I'm gonna see if I'm missing something and/or if it's worth to keep booleans when they are the only advanced index group.

Either way I would leave it for a separate PR, as tests will certainly be needed.

@ricardoV94
Copy link
Member Author

ricardoV94 commented May 29, 2025

Another point in favor is that np.ix_ also converts bools with nonzero(): https://github.com/numpy/numpy/blob/7d2e4418e17eeaed1045e37335bacf0f01343f4d/numpy/lib/_index_tricks_impl.py#L102

So now the question is only if we want to ever not do it? Maybe we add a special branch if there's a single bool index and there's no other advanced indexing going on (or at least no other advanced orthogonal indexing)?

@OriolAbril
Copy link
Member

So now the question is only if we want to ever not do it? Maybe we add a special branch if there's a single bool index and there's no other advanced indexing going on (or at least no other advanced orthogonal indexing)?

It looks like that is the best way to go. From https://numpy.org/devdocs//user/basics.indexing.html#boolean-array-indexing:

A single boolean index array is practically identical to x[obj.nonzero()] where, as described above, obj.nonzero() returns a tuple (of length obj.ndim) of integer index arrays showing the True elements of obj. However, it is faster when obj.shape == x.shape.

@OriolAbril
Copy link
Member

but also with xarray rules on boolean indexes x.shape == obj.shape can only happen for 1d arrays 🤔

@ricardoV94
Copy link
Member Author

ricardoV94 commented May 30, 2025

but also with xarray rules on boolean indexes x.shape == obj.shape can only happen for 1d arrays 🤔

Yeah but the following are fine:

import numpy as np
import xarray as xr

x = xr.DataArray(np.zeros((5, 5)), dims=("a", "b"))
x.values[np.array([True, False, False, True, True]), 1:]
x.values[np.array([True, False, False, True, True]), 0]
x.values[np.array([True, False, False, True, True]), [0, 2, 4]]

Now I think in the last case numpy will convert the boolean array to nonzero anyway, but not in the first two? Hence if that's a tiny bit more optimal we can leave it as a special case with boolean.

There's also the issue that we are not using the C implementation of NonZero yet, so it adds a tiny bit extra overhead in the default backend: #1361

OTOH if nonzero is a constant (most cases), it may be better. I'm inclined to going with .nonzero() for now and thinking about performance more carefully later

@ricardoV94
Copy link
Member Author

I did some benchmarks and boolean indexing doesn't seem to be any faster, at least the way PyTensor uses it so I suggest we actually get rid of them and always use nonzero, both here and later in tensor operations as well: #1432

@ricardoV94 ricardoV94 changed the title Implement non-boolean index operations for XTensorVariables Implement indexing operations for XTensorVariables May 30, 2025
@ricardoV94
Copy link
Member Author

@OriolAbril I went and lifted the boolean restriction, always converting it with nonzero

@ricardoV94 ricardoV94 requested review from OriolAbril and Copilot May 30, 2025 10:09
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

Adds full non-boolean indexing support for XTensorVariables and related tests

  • Implements __getitem__, isel, head/tail/thin on XTensorVariable with ellipsis handling
  • Introduces rewriting rules to lower XTensor indexing to Tensor indexing
  • Adds comprehensive indexing tests, including scalar, vector, matrix, boolean, and error cases

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.

Show a summary per file
File Description
tests/xtensor/test_math.py New test_scalar_case for scalar-addition in xtensor
tests/xtensor/test_indexing.py New suite of tests covering all indexing modes and error paths
pytensor/xtensor/vectorization.py Guard for empty dims_and_shape when vectorizing scalar outputs
pytensor/xtensor/type.py Shape‐length validation; full __getitem__ and isel impl.; head/tail/thin
pytensor/xtensor/rewriting/indexing.py New op rewrite to lower XTensor indexing into Tensor ops
pytensor/xtensor/rewriting/init.py Registers the new indexing rewrite module
pytensor/xtensor/indexing.py Index Op and helper functions for Python‐level XTensor indexing
pytensor/xtensor/init.py Removed unused XTensorType export
Comments suppressed due to low confidence (3)

pytensor/xtensor/indexing.py:2

  • Typo in the header comment: 'Uselful' should be 'Useful'.
# Uselful links to make sense of all the numpy/xarray complexity

pytensor/xtensor/type.py:52

  • The new constructor check for mismatched dims and shape length should be covered by a test that verifies it raises ValueError when they differ.
if len(self.shape) != len(self.dims):

pytensor/xtensor/init.py:9

  • [nitpick] Removing XTensorType from this module’s exports may break users relying on it; consider keeping a deprecated alias or updating docs.
-    XTensorType,

@OriolAbril
Copy link
Member

I don't think it is urgent but just in case, I don't think I'll be able to review or test it out until next week

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants