Commit b2767de

Implement non-boolean index operations for XTensorVariables

1 parent 247da0b

6 files changed: +664 −3 lines

pytensor/xtensor/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 )
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
-    XTensorType,
     as_xtensor,
     xtensor,
     xtensor_constant,

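Note: with this re-export removed, `XTensorType` is imported from its defining module instead, exactly as the new `pytensor/xtensor/indexing.py` below does:

from pytensor.xtensor.type import XTensorType
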
pytensor/xtensor/indexing.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
# HERE LIE DRAGONS
# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor import TensorType
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


def as_idx_variable(idx):
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )
    if isinstance(idx, slice):
        idx = make_slice(idx)
    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        pass
    elif isinstance(idx, tuple) and len(idx) == 2 and isinstance(idx[0], str):
        # Special case for ("x", array) that xarray supports
        # TODO: Check if this can be used to rename existing xarray dimensions or only for numpy
        dim, idx = idx
        idx = xtensor_from_tensor(as_tensor(idx), dims=(dim,))
    else:
        # Must be integer indices, we already accounted for None and slices
        try:
            idx = as_xtensor(idx)
        except TypeError:
            idx = as_tensor(idx)
        if idx.type.dtype == "bool":
            raise NotImplementedError("Boolean indexing not yet supported")
        if idx.type.dtype not in discrete_dtypes:
            raise TypeError("Numerical indices must be integers or boolean")
        if idx.type.dtype == "bool" and idx.type.ndim == 0:
            # This can't be triggered right now, but will once we lift the boolean restriction
            raise NotImplementedError("Scalar boolean indices not supported")
    return idx


def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    if dim_length is None:
        return None
    if isinstance(slc, Constant):
        d = slc.data
        start, stop, step = d.start, d.stop, d.step
    elif slc.owner is None:
        # It's a root variable, no way of knowing what we're getting
        return None
    else:
        # It's a MakeSliceOp
        start, stop, step = slc.owner.inputs
        if isinstance(start, Constant):
            start = start.data
        else:
            return None
        if isinstance(stop, Constant):
            stop = stop.data
        else:
            return None
        if isinstance(step, Constant):
            step = step.data
        else:
            return None
    return len(range(*slice(start, stop, step).indices(dim_length)))


class Index(XOp):
    __props__ = ()

    def make_node(self, x, *idxs):
        x = as_xtensor(x)
        idxs = [as_idx_variable(idx) for idx in idxs]

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        out_dims = []
        out_shape = []

        def combine_dim_info(idx_dim, idx_dim_shape):
            if idx_dim not in out_dims:
                # First information about the dimension length
                out_dims.append(idx_dim)
                out_shape.append(idx_dim_shape)
            else:
                # Dim already introduced in output by a previous index
                # Update static shape or raise if incompatible
                out_dim_pos = out_dims.index(idx_dim)
                out_dim_shape = out_shape[out_dim_pos]
                if out_dim_shape is None:
                    # We don't know the size of the dimension yet
                    out_shape[out_dim_pos] = idx_dim_shape
                elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape:
                    raise IndexError(
                        f"Dimension of indexers mismatch for dim {idx_dim}"
                    )

        for i, idx in enumerate(idxs):
            if i == x_ndim:
                raise IndexError("Too many indices")
            if isinstance(idx.type, SliceType):
                idx_dim = x_dims[i]
                idx_dim_shape = get_static_slice_length(idx, x_shape[i])
                combine_dim_info(idx_dim, idx_dim_shape)
            else:
                if idx.type.ndim == 0:
                    # Scalar index, dimension is dropped
                    continue

                if isinstance(idx.type, TensorType):
                    if idx.type.ndim > 1:
                        # Same error that xarray raises
                        raise IndexError(
                            "Unlabeled multi-dimensional array cannot be used for indexing"
                        )

                    # This is implicitly an XTensorVariable with dim matching the indexed one
                    idx = idxs[i] = xtensor_from_tensor(idx, dims=(x_dims[i],))

                assert isinstance(idx.type, XTensorType)

                idx_dims = idx.type.dims
                for idx_dim in idx_dims:
                    idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)]
                    combine_dim_info(idx_dim, idx_dim_shape)

        for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]):
            # Add back any unindexed dimensions
            if dim_i not in out_dims:
                # If the dimension was not indexed, we keep it as is
                combine_dim_info(dim_i, shape_i)

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
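
The static shape inference in `get_static_slice_length` rests on a plain-Python idiom: `slice.indices` normalizes start/stop/step (including None and negative values) against a known dimension length, and `len(range(...))` counts the selected positions. A minimal standalone sketch (pure Python, hypothetical function name):

def static_slice_length(start, stop, step, dim_length):
    # slice.indices clips and normalizes the bounds to dim_length;
    # len(range(...)) then counts how many indices the slice selects.
    return len(range(*slice(start, stop, step).indices(dim_length)))

assert static_slice_length(None, None, None, 10) == 10  # full slice
assert static_slice_length(2, None, 3, 10) == 3         # selects 2, 5, 8
assert static_slice_length(None, None, -1, 10) == 10    # reversed slice
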
pytensor/xtensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 import pytensor.xtensor.rewriting.basic
+import pytensor.xtensor.rewriting.indexing
 import pytensor.xtensor.rewriting.reduction
 import pytensor.xtensor.rewriting.shape
 import pytensor.xtensor.rewriting.vectorization
pytensor/xtensor/rewriting/indexing.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
from itertools import zip_longest

from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, arange, specify_shape
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
from pytensor.xtensor.type import XTensorType


def to_basic_idx(idx):
    if isinstance(idx.type, SliceType):
        if isinstance(idx, Constant):
            return idx.data
        elif idx.owner:
            # MakeSlice Op
            # We transform NoneConsts to regular None so that basic Subtensor can be used if possible
            return slice(
                *[
                    None if isinstance(i.type, NoneTypeT) else i
                    for i in idx.owner.inputs
                ]
            )
        else:
            return idx
    if (
        isinstance(idx.type, XTensorType)
        and idx.type.ndim == 0
        and idx.type.dtype != "bool"
    ):
        return idx.values
    raise TypeError("Cannot convert idx to basic idx")


@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
    """Lower XTensorVariable indexing to regular TensorVariable indexing.

    xarray-like indexing has two modes:
    1. Orthogonal indexing: Indices of different output labeled dimensions are combined to produce all combinations of indices.
    2. Vectorized indexing: Indices of the same output labeled dimension are combined point-wise like in regular numpy advanced indexing.

    An Index Op can combine both modes.
    To achieve orthogonal indexing using numpy semantics we must use multidimensional advanced indexing.
    We expand the dims of each index so they are as large as the number of output dimensions, place the indices that
    belong to the same output dimension in the same axis, and those that belong to different output dimensions in different axes.

    For instance, to do an outer 2x2 indexing we can select x[arange(x.shape[0])[:, None], arange(x.shape[1])[None, :]].
    This is a generalization of `np.ix_` that allows combining some dimensions, and not others, as well as having
    indices with more than one dimension to start with.

    In addition, xarray basic indices (slices) can be vectorized with other advanced indices (if they act on the same output dimension).
    However, in numpy, basic indices are always orthogonal to advanced indices. To make them behave like vectorized indices
    we have to convert the slices to equivalent advanced indices.
    We do this by creating an `arange` tensor that matches the shape of the dimension being indexed,
    and then indexing it with the original slice. This index is then handled as a regular advanced index.

    Note: The Index Op has only 2 types of indices: Slices and XTensorVariables. Regular array indices
    are converted to the appropriate XTensorVariable by `Index.make_node`.
    """

    x, *idxs = node.inputs
    [out] = node.outputs
    x_tensor = tensor_from_xtensor(x)

    if all(
        (
            isinstance(idx.type, SliceType)
            or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0)
        )
        for idx in idxs
    ):
        # Special case having just basic indexing
        x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]

    else:
        # General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing
        # May need to convert basic indexing to advanced indexing if it acts on a dimension that is also indexed by an advanced index
        x_dims = x.type.dims
        x_shape = tuple(x.shape)
        out_ndim = out.type.ndim
        out_dims = out.type.dims
        aligned_idxs = []
        basic_idx_axis = []
        # zip_longest adds the implicit slice(None)
        for i, (idx, x_dim) in enumerate(
            zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None)))
        ):
            if isinstance(idx.type, SliceType):
                if not any(
                    (
                        isinstance(other_idx.type, XTensorType)
                        and x_dim in other_idx.dims
                    )
                    for j, other_idx in enumerate(idxs)
                    if j != i
                ):
                    # We can use basic indexing directly if no other index acts on this dimension
                    # This is an optimization that avoids creating an unnecessary arange tensor
                    # and facilitates the use of the specialized AdvancedSubtensor1 when possible
                    aligned_idxs.append(idx)
                    basic_idx_axis.append(out_dims.index(x_dim))
                else:
                    # Otherwise we need to convert the basic index into an equivalent advanced indexing
                    # and align it so it interacts correctly with the other advanced indices
                    adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)]
                    ds_order = ["x"] * out_ndim
                    ds_order[out_dims.index(x_dim)] = 0
                    aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order))
            else:
                assert isinstance(idx.type, XTensorType)
                if idx.type.ndim == 0:
                    # Scalar index, we can use it directly
                    aligned_idxs.append(idx.values)
                else:
                    # Vector index, we need to align the indexing dimensions with the base_dims
                    ds_order = ["x"] * out_ndim
                    for j, idx_dim in enumerate(idx.dims):
                        ds_order[out_dims.index(idx_dim)] = j
                    aligned_idxs.append(idx.values.dimshuffle(ds_order))

        # Squeeze indexing dimensions that were not used because we kept basic indexing slices
        if basic_idx_axis:
            aligned_idxs = [
                idx.squeeze(axis=basic_idx_axis)
                if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
                else idx
                for idx in aligned_idxs
            ]

        x_tensor_indexed = x_tensor[tuple(aligned_idxs)]

        if basic_idx_axis and _non_consecutive_adv_indexing(aligned_idxs):
            # Numpy moves advanced indexing dimensions to the front when they are not consecutive
            # We need to transpose them back to the expected output order
            x_tensor_indexed_basic_dims = [out_dims[idx] for idx in basic_idx_axis]
            x_tensor_indexed_dims = [
                dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims
            ] + x_tensor_indexed_basic_dims
            transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
            x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)

    # Add lost shape information
    x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
    return [new_out]
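
The two numpy behaviors this rewrite leans on can be checked directly in numpy. A small sketch (plain numpy, illustrative variable names only):

import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# Orthogonal indexing via aligned advanced indices (the np.ix_-style
# alignment described in the docstring): indices for different output
# dimensions sit on different axes, so broadcasting yields all combinations.
rows = np.arange(2)[:, None]  # aligned to output dim 0
cols = np.arange(3)[None, :]  # aligned to output dim 1
assert x[rows, cols].shape == (2, 3, 4)

# Non-consecutive advanced indices: numpy moves the broadcast advanced
# dimensions to the front, which is why lower_index transposes them back.
adv = np.array([0, 1])
assert x[adv, :, adv].shape == (2, 3)  # advanced dim first, sliced dim after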
