Implement indexing operations for XTensorVariables #1429
base: labeled_tensors
Pull Request Overview
This PR implements full indexing support for XTensorVariables by adding indexing operations (`__getitem__`, `isel`, `head`/`tail`/`thin`), lowering logic in the rewrite pass, and comprehensive tests.
- Add a new suite of indexing tests in `tests/xtensor/test_indexing.py` and cover the scalar case in `test_math.py`
- Implement `__getitem__`, `isel`, and related slicing utilities in `pytensor/xtensor/type.py`
- Introduce an `Index` op, lowering logic (`rewriting/indexing.py`), and integrate with vectorization fallbacks
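As a rough illustration of what the `head`/`tail`/`thin` utilities mean, the sketch below expresses them as plain NumPy slices along one axis. It follows xarray's documented semantics (first `n` entries, last `n` entries, every `n`-th entry); the helper names and the `axis` argument are hypothetical and are not taken from the PR's code.

```python
import numpy as np


def _slice_along(x, axis, slc):
    """Apply a slice along a single axis, leaving the other axes untouched."""
    key = [slice(None)] * x.ndim
    key[axis] = slc
    return x[tuple(key)]


def head(x, n, axis=0):
    # First n entries along the axis (n > 0 assumed)
    return _slice_along(x, axis, slice(None, n))


def tail(x, n, axis=0):
    # Last n entries along the axis (n > 0 assumed)
    return _slice_along(x, axis, slice(-n, None))


def thin(x, n, axis=0):
    # Every n-th entry along the axis, starting at the first one
    return _slice_along(x, axis, slice(None, None, n))


x = np.arange(10)
first_three = head(x, 3)
last_three = tail(x, 3)
every_fourth = thin(x, 4)
```

Because all three reduce to basic slices, they would not need any advanced-indexing machinery when lowered to tensor operations.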
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
tests/xtensor/test_math.py | Add test_scalar_case for zero-dimensional tensors |
tests/xtensor/test_indexing.py | New tests covering basic, advanced, vector, matrix, and scalar indexing |
pytensor/xtensor/vectorization.py | Handle empty dims_and_shape in vectorization |
pytensor/xtensor/type.py | Implement __getitem__ , isel , and dimension checks |
pytensor/xtensor/indexing.py | New Index op and helpers for converting indices |
pytensor/xtensor/rewriting/indexing.py | Lower XTensor indexing to TensorVariable indexing |
pytensor/xtensor/rewriting/__init__.py | Register the new indexing rewriter |
pytensor/xtensor/__init__.py | Remove unused XTensorType import |
Comments suppressed due to low confidence (2)
tests/xtensor/test_indexing.py:32
- [nitpick] The variable name `shufled_dims` is misspelled and ambiguous; consider renaming it to `shuffled_dims`.
shufled_dims = tuple(np.random.permutation(dims))
pytensor/xtensor/indexing.py:3
- Typo in comment: `Uselful` should be `Useful`.
# Uselful links to make sense of all the numpy/xarray complexity
I read through the code and couldn't think of cases that would break. I also played with it for a while and everything seemed to work. I left some minor examples.
The main comment is I think we need more tests. The existing ones cover most branches already though. Happy to add some more here or later on.
I agree, converting the provided boolean indexes to their integer equivalent as a first step should be the only thing needed to support them.
Extra tests welcome @OriolAbril just push
Do we even need to convert?
I guess it will depend on what exactly the lowering does. For example:

```python
# obs is DataArray (chain: 4, draw: 500, obs_id: 77)
# chain_idxs, draw_idxs are length 200 integer DataArrays with dims ("divergence_id",)
# mask is a length 77 boolean array with 16 Trues in it.
# This works
obs.isel(chain=chain_idxs, draw=draw_idxs, obs_id=mask)
# obs[chain_idxs, draw_idxs, mask]
# but both attempts below fail
obs.values[chain_idxs.values, draw_idxs.values, mask]
# IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (200,) (200,) (16,)
obs.values[chain_idxs.values[:, None], draw_idxs.values[:, None], mask[None, :]]
# IndexError: too many indices
```

IIUC, the current lowering will generate indexes like I do in the 2nd failed attempt and forward to tensor indexing. On the other hand the snippet below does work:

```python
mask_idx, = np.nonzero(mask)
obs.values[chain_idxs.values[:, None], draw_idxs.values[:, None], mask_idx[None, :]]
```
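For anyone who wants to reproduce this without xarray, here is a self-contained NumPy sketch of the working variant. The array contents are random stand-ins with the shapes described above, and the function name `orthogonal_mask_index` is made up for illustration; it is not part of PyTensor or of this PR.

```python
import numpy as np


def orthogonal_mask_index(values, chain_idxs, draw_idxs, mask):
    """Index pointwise along (chain, draw) and orthogonally along obs_id."""
    # Convert the boolean mask to integer positions first
    (mask_idx,) = np.nonzero(mask)
    # The two pointwise indexers share axis 0, the mask positions get axis 1
    return values[chain_idxs[:, None], draw_idxs[:, None], mask_idx[None, :]]


rng = np.random.default_rng(0)
values = rng.normal(size=(4, 500, 77))
chain_idxs = rng.integers(0, 4, size=200)
draw_idxs = rng.integers(0, 500, size=200)
mask = np.zeros(77, dtype=bool)
mask[:16] = True

result = orthogonal_mask_index(values, chain_idxs, draw_idxs, mask)
print(result.shape)  # (200, 16)
```

The result matches a two-step lookup, `values[chain_idxs, draw_idxs][:, mask]`, which is one way to check that the integer conversion keeps the intended semantics.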
@OriolAbril thanks for checking, it seems we need to convert to integers for the general case then. I'm gonna see if I'm missing something and/or if it's worth keeping booleans when they are the only advanced index group. Either way I would leave it for a separate PR, as tests will certainly be needed.
Another point in favor is that So now the question is only if we want to ever not do it? Maybe we add a special branch if there's a single bool index and there's no other advanced indexing going on (or at least no other advanced orthogonal indexing)?
It looks like that is the best way to go. From https://numpy.org/devdocs//user/basics.indexing.html#boolean-array-indexing:
but also with xarray rules on boolean indexes
Yeah but the following are fine:

```python
import numpy as np
import xarray as xr
x = xr.DataArray(np.zeros((5, 5)), dims=("a", "b"))
x.values[np.array([True, False, False, True, True]), 1:]
x.values[np.array([True, False, False, True, True]), 0]
x.values[np.array([True, False, False, True, True]), [0, 2, 4]]
```

Now I think in the last case numpy will convert the boolean array to
There's also the issue that we are not using the C implementation of NonZero yet, so it adds a tiny bit extra overhead in the default backend: #1361 OTOH if
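A quick way to check these three cases is to compare each boolean lookup against the same lookup done with the mask's integer positions. The sketch below uses plain NumPy and an `arange` array instead of zeros so the values are distinguishable; it only demonstrates NumPy's behavior and makes no claim about how PyTensor lowers the index.

```python
import numpy as np

x = np.arange(25).reshape(5, 5)
mask = np.array([True, False, False, True, True])
(mask_idx,) = np.nonzero(mask)  # integer positions 0, 3, 4

# Slice, scalar, and integer-list companions all agree with the nonzero form
keys = (slice(1, None), 0, [0, 2, 4])
agreement = [np.array_equal(x[mask, key], x[mask_idx, key]) for key in keys]

# With another advanced index, the mask positions are paired pointwise
paired = x[mask, [0, 2, 4]]  # picks x[0, 0], x[3, 2], x[4, 4]
```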
I did some benchmarks and boolean indexing doesn't seem to be any faster, at least the way PyTensor uses it, so I suggest we actually get rid of them and always use nonzero, both here and later in tensor operations as well: #1432
@OriolAbril I went and lifted the boolean restriction, always converting it with nonzero
Pull Request Overview
Adds full non-boolean indexing support for XTensorVariables and related tests
- Implements `__getitem__`, `isel`, `head`/`tail`/`thin` on `XTensorVariable` with ellipsis handling
- Introduces rewriting rules to lower XTensor indexing to Tensor indexing
- Adds comprehensive indexing tests, including scalar, vector, matrix, boolean, and error cases
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.
File | Description |
---|---|
tests/xtensor/test_math.py | New test_scalar_case for scalar-addition in xtensor |
tests/xtensor/test_indexing.py | New suite of tests covering all indexing modes and error paths |
pytensor/xtensor/vectorization.py | Guard for empty dims_and_shape when vectorizing scalar outputs |
pytensor/xtensor/type.py | Shape-length validation; full __getitem__ and isel impl.; head/tail/thin |
pytensor/xtensor/rewriting/indexing.py | New op rewrite to lower XTensor indexing into Tensor ops |
pytensor/xtensor/rewriting/__init__.py | Registers the new indexing rewrite module |
pytensor/xtensor/indexing.py | Index Op and helper functions for Python-level XTensor indexing |
pytensor/xtensor/__init__.py | Removed unused XTensorType export |
Comments suppressed due to low confidence (3)
pytensor/xtensor/indexing.py:2
- Typo in the header comment: 'Uselful' should be 'Useful'.
# Uselful links to make sense of all the numpy/xarray complexity
pytensor/xtensor/type.py:52
- The new constructor check for mismatched `dims` and `shape` length should be covered by a test that verifies it raises `ValueError` when they differ.
if len(self.shape) != len(self.dims):
pytensor/xtensor/__init__.py:9
- [nitpick] Removing `XTensorType` from this module’s exports may break users relying on it; consider keeping a deprecated alias or updating docs.
- XTensorType,
I don't think it is urgent but just in case, I don't think I'll be able to review or test it out until next week
Split the work on indexing from #1411
This implements all forms of non-boolean indexing in xarray. I think the same framework works for boolean indices which are restricted to be vectors in xarray, but I haven't implemented those yet.
It deviates from xarray in that it doesn't try to prevent runtime broadcasting between indexers when lowering to tensor index operations. There's a test marked as xfail. I'm wary of complicating the graph with a bunch of `specify_shape` but it's something we can consider doing if users complain. It does prevent compile-time known broadcasting (so if everything has static shapes it will error).
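To make the broadcasting point concrete, here is a small NumPy-only sketch of what plain tensor indexing allows between two pointwise indexers. The helper name `index_pair` is hypothetical, and the comparison with xarray relies only on the statement above that xarray prevents this kind of broadcasting.

```python
import numpy as np


def index_pair(x, rows, cols):
    """Pointwise (vectorized) indexing exactly as plain NumPy performs it."""
    return x[np.asarray(rows), np.asarray(cols)]


x = np.arange(12).reshape(3, 4)

# Indexers of equal length: pairs up (0, 0), (1, 1), (2, 2)
same_length = index_pair(x, [0, 1, 2], [0, 1, 2])

# A length-1 indexer is silently broadcast against a length-3 one,
# which is the runtime broadcasting the lowering does not guard against
broadcasted = index_pair(x, [0], [0, 1, 2])
```

When shapes are only known at runtime, catching the second case would presumably require shape assertions in the graph, which is the `specify_shape` trade-off mentioned above.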