Skip to content

Commit e3f2324

Browse files
committed
WIP Implement index operations for XTensorVariables
1 parent 7edc5a9 commit e3f2324

File tree

6 files changed

+361
-3
lines changed

6 files changed

+361
-3
lines changed

pytensor/xtensor/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
)
88
from pytensor.xtensor.shape import concat
99
from pytensor.xtensor.type import (
10-
XTensorType,
1110
as_xtensor,
1211
xtensor,
1312
xtensor_constant,

pytensor/xtensor/indexing.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# HERE LIE DRAGONS
2+
# Useful links to make sense of all the numpy/xarray complexity
3+
# https://numpy.org/devdocs/user/basics.indexing.html
4+
# https://numpy.org/neps/nep-0021-advanced-indexing.html
5+
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
6+
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
7+
8+
from pytensor.graph.basic import Apply, Constant, Variable
9+
from pytensor.scalar.basic import discrete_dtypes
10+
from pytensor.tensor.basic import as_tensor
11+
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
12+
from pytensor.xtensor.basic import XOp
13+
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
14+
15+
16+
def as_idx_variable(idx):
    """Convert a raw index into a variable that the ``Index`` Op can consume.

    Python slices are turned into slice variables with ``make_slice``, slice
    variables are returned untouched, and everything else is converted with
    ``as_tensor`` (falling back to ``as_xtensor`` when that raises
    ``TypeError``).

    Raises
    ------
    TypeError
        If ``idx`` is ``None`` or a ``NoneTypeT`` variable, or if the
        converted index does not have a discrete dtype.
    NotImplementedError
        If the converted index has a boolean dtype.
    """
    if idx is None:
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )
    if isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )

    if isinstance(idx, slice):
        return make_slice(idx)
    if isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        return idx

    # Must be integer indices, None and slices were already accounted for
    try:
        idx_var = as_tensor(idx)
    except TypeError:
        idx_var = as_xtensor(idx)

    idx_dtype = idx_var.type.dtype
    if idx_dtype == "bool":
        raise NotImplementedError("Boolean indexing not yet supported")
    if idx_dtype not in discrete_dtypes:
        raise TypeError("Numerical indices must be integers or boolean")
    if idx_dtype == "bool" and idx_var.type.ndim == 0:
        # This can't be triggered right now, but will once we lift the boolean restriction
        raise NotImplementedError("Scalar boolean indices not supported")
    return idx_var
39+
40+
41+
def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    """Return the statically known length of ``slc`` applied to a dimension.

    ``None`` is returned whenever the length cannot be determined: the
    dimension length itself is unknown, the slice is a root variable, or any
    of the start/stop/step inputs of the slice-building node is not a
    ``Constant``.
    """
    if dim_length is None:
        return None

    if isinstance(slc, Constant):
        static_slice = slc.data
        bounds = [static_slice.start, static_slice.stop, static_slice.step]
    elif slc.owner is None:
        # It's a root variable no way of knowing what we're getting
        return None
    else:
        # It's a MakeSliceOp
        start_var, stop_var, step_var = slc.owner.inputs
        bounds = []
        for bound_var in (start_var, stop_var, step_var):
            if not isinstance(bound_var, Constant):
                return None
            bounds.append(bound_var.data)

    return len(range(*slice(*bounds).indices(dim_length)))
66+
67+
68+
class Index(XOp):
    """Op that indexes an xtensor with slices, integer tensors or labeled xtensors.

    ``make_node`` only works out the output dims and static shape; this block
    contains no ``perform`` implementation.
    """

    __props__ = ()

    def make_node(self, x, *idxs):
        """Build the Apply node, inferring output dims and static shape.

        Parameters
        ----------
        x
            Anything accepted by ``as_xtensor``.
        *idxs
            Positional indices, each converted with ``as_idx_variable``.
            Dimensions of ``x`` beyond the last index are kept unchanged.

        Raises
        ------
        IndexError
            If there are more indices than dimensions, if labeled indexers
            disagree on the static length of a shared dim, or if an unlabeled
            index has more than one dimension.
        NotImplementedError
            If labeled and unlabeled vector indices are mixed.
        """
        x = as_xtensor(x)
        idxs = [as_idx_variable(idx) for idx in idxs]

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        out_dims = []
        out_shape = []
        has_unlabeled_vector_idx = False
        has_labeled_vector_idx = False
        for i, idx in enumerate(idxs):
            if i == x_ndim:
                raise IndexError("Too many indices")
            if isinstance(idx.type, SliceType):
                out_dims.append(x_dims[i])
                out_shape.append(get_static_slice_length(idx, x_shape[i]))
            elif isinstance(idx.type, XTensorType):
                if has_unlabeled_vector_idx:
                    raise NotImplementedError(
                        "Mixing of labeled and unlabeled vector indexing not implemented"
                    )
                has_labeled_vector_idx = True
                for dim, idx_dim_shape in zip(idx.type.dims, idx.type.shape):
                    if dim in out_dims:
                        # Dim already introduced in output by a previous index
                        # Update static shape or raise if incompatible
                        out_dim_pos = out_dims.index(dim)
                        out_dim_shape = out_shape[out_dim_pos]
                        if out_dim_shape is None:
                            # We don't know the size of the dimension yet
                            out_shape[out_dim_pos] = idx_dim_shape
                        elif (
                            idx_dim_shape is not None and idx_dim_shape != out_dim_shape
                        ):
                            raise IndexError(
                                f"Dimension of indexers mismatch for dim {dim}"
                            )
                    else:
                        # New dimension
                        out_dims.append(dim)
                        out_shape.append(idx_dim_shape)

            else:  # TensorType
                if idx.type.ndim == 0:
                    # Scalar, dimension is dropped
                    pass
                elif idx.type.ndim == 1:
                    if has_labeled_vector_idx:
                        raise NotImplementedError(
                            "Mixing of labeled and unlabeled vector indexing not implemented"
                        )
                    has_unlabeled_vector_idx = True
                    out_dims.append(x_dims[i])
                    out_shape.append(idx.type.shape[0])
                else:
                    # Same error that xarray raises
                    raise IndexError(
                        "Unlabeled multi-dimensional array cannot be used for indexing"
                    )

        # Add any unindexed dimensions. Start from len(idxs) rather than the
        # loop variable so that an empty index list keeps every dimension
        # instead of failing on an unbound loop variable.
        for j in range(len(idxs), x_ndim):
            out_dims.append(x_dims[j])
            out_shape.append(x_shape[j])

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])


index = Index()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytensor.xtensor.rewriting.basic
2+
import pytensor.xtensor.rewriting.indexing
23
import pytensor.xtensor.rewriting.reduction
34
import pytensor.xtensor.rewriting.shape
45
import pytensor.xtensor.rewriting.vectorization
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from pytensor.graph import Constant, node_rewriter
2+
from pytensor.tensor import TensorType, specify_shape
3+
from pytensor.tensor.type_other import NoneTypeT, SliceType
4+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
5+
from pytensor.xtensor.indexing import Index
6+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
7+
from pytensor.xtensor.type import XTensorType
8+
9+
10+
def to_basic_idx(idx):
    """Convert an index variable into something basic tensor indexing accepts.

    Constant slices are unwrapped to their data, slices built by a node are
    rebuilt as Python ``slice`` objects (with ``NoneTypeT`` inputs replaced by
    ``None``), root slice variables are returned as-is, and non-boolean scalar
    tensors/xtensors are passed through.

    Raises
    ------
    TypeError
        If ``idx`` is none of the above (e.g. a vector or a boolean scalar).
    """
    if isinstance(idx.type, SliceType):
        if isinstance(idx, Constant):
            return idx.data
        elif idx.owner:
            # MakeSlice Op
            # We transform NoneConsts to regular None so that basic Subtensor can be used if possible
            return slice(
                *[
                    None if isinstance(i.type, NoneTypeT) else i
                    for i in idx.owner.inputs
                ]
            )
        else:
            return idx
    if (
        isinstance(idx.type, XTensorType | TensorType)
        and idx.type.ndim == 0
        # dtype is a string, so it must be compared against "bool";
        # comparing against the builtin ``bool`` type is always unequal.
        and idx.type.dtype != "bool"
    ):
        return idx
    raise TypeError("Cannot convert idx to basic idx")
32+
33+
34+
def _count_idx_types(idxs):
35+
basic, vector, xvector = 0, 0, 0
36+
for idx in idxs:
37+
if isinstance(idx.type, SliceType):
38+
basic += 1
39+
elif idx.type.ndim == 0:
40+
basic += 1
41+
elif isinstance(idx.type, TensorType):
42+
vector += 1
43+
else:
44+
xvector += 1
45+
return basic, vector, xvector
46+
47+
48+
@register_xcanonicalize
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
    """Rewrite an ``Index`` node into regular tensor indexing.

    Handles only the cases with no labeled indices and at most one unlabeled
    vector index; returns ``None`` (no rewrite) otherwise.
    """
    xtensor_in, *index_vars = node.inputs
    (xtensor_out,) = node.outputs
    raw_tensor = tensor_from_xtensor(xtensor_in)

    _, n_vector, n_xvector = _count_idx_types(index_vars)
    if n_xvector != 0 or n_vector > 1:
        # Not yet implemented
        return None

    if n_vector == 0:
        indexed = raw_tensor[tuple(to_basic_idx(index_var) for index_var in index_vars)]
    else:
        # Special case for single vector index, no orthogonal indexing
        indexed = raw_tensor[tuple(index_vars)]

    # Add lost shape if any
    indexed = specify_shape(indexed, xtensor_out.type.shape)
    return [xtensor_from_tensor(indexed, dims=xtensor_out.type.dims)]

pytensor/xtensor/type.py

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from pytensor.tensor import TensorType
24
from pytensor.tensor.math import variadic_mul
35

@@ -10,7 +12,7 @@
1012
XARRAY_AVAILABLE = False
1113

1214
from collections.abc import Sequence
13-
from typing import TypeVar
15+
from typing import Any, Literal, TypeVar
1416

1517
import numpy as np
1618

@@ -340,7 +342,112 @@ def sel(self, *args, **kwargs):
340342
raise NotImplementedError("sel not implemented for XTensorVariable")
341343

342344
def __getitem__(self, idx):
    """Index this variable positionally, or by dimension name with a dict.

    A dict is forwarded to ``isel``. Any other index is treated as a tuple of
    positional indices (a single non-tuple index is wrapped), and one
    ``Ellipsis`` is expanded into as many ``slice(None)`` as needed to cover
    the remaining dimensions.

    Raises
    ------
    IndexError
        If more than one ``Ellipsis`` is present.
    """
    if isinstance(idx, dict):
        return self.isel(idx)

    # A single index (``x[0]``, ``x[1:]``) is not passed as a tuple
    if not isinstance(idx, tuple):
        idx = (idx,)

    # Locate Ellipsis by identity: ``==`` on array-like or symbolic items
    # is not a safe way to test for it.
    ellipsis_locs = [loc for loc, idx_item in enumerate(idx) if idx_item is Ellipsis]
    if len(ellipsis_locs) > 1:
        raise IndexError("an index can only have a single ellipsis ('...')")
    if ellipsis_locs:
        # Convert Ellipsis to the implied number of slice(None)
        (ellipsis_loc,) = ellipsis_locs
        n_implied_none_slices = self.type.ndim - (len(idx) - 1)
        idx = (
            *idx[:ellipsis_loc],
            *((slice(None),) * n_implied_none_slices),
            *idx[ellipsis_loc + 1 :],
        )

    return px.indexing.index(self, *idx)
362+
363+
def isel(
    self,
    indexers: dict[str, Any] | None = None,
    drop: bool = False,  # Unused by PyTensor
    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
    **indexers_kwargs,
):
    """Index by dimension name.

    Parameters
    ----------
    indexers
        Mapping from dimension name to index. Mutually exclusive with
        ``indexers_kwargs``. ``None`` with no keyword indexers selects
        everything.
    drop
        Accepted but unused.
    missing_dims
        What to do when a key is not one of this variable's dims: raise a
        ``ValueError``, emit a ``UserWarning`` and skip it, or silently
        skip it.
    **indexers_kwargs
        Keyword form of ``indexers``.

    Raises
    ------
    ValueError
        If both forms of indexers are given, if ``missing_dims`` is not a
        recognized option, or if a dimension is missing and
        ``missing_dims="raise"``.
    TypeError
        If an index is ``Ellipsis``.
    """
    if indexers_kwargs:
        if indexers is not None:
            raise ValueError(
                "Cannot pass both indexers and indexers_kwargs to isel"
            )
        indexers = indexers_kwargs
    elif indexers is None:
        # No indexers at all: every dimension keeps its full slice
        indexers = {}

    if missing_dims not in {"raise", "warn", "ignore"}:
        raise ValueError(
            f"Unrecognized options {missing_dims} for missing_dims argument"
        )

    # Sort indices and pass them to index
    dims = self.type.dims
    indices = [slice(None)] * self.type.ndim
    for key, idx in indexers.items():
        if idx is Ellipsis:
            # Xarray raises a less informative error, suggesting indices must be integer
            # But slices are also fine
            raise TypeError("Ellipsis (...) is an invalid labeled index")
        try:
            # Sequence.index signals a missing item with ValueError
            dim_pos = dims.index(key)
        except ValueError:
            msg = f"Dimension {key} does not exist. Expected one of {dims}"
            if missing_dims == "raise":
                raise ValueError(msg) from None
            if missing_dims == "warn":
                warnings.warn(msg, UserWarning)
            continue
        indices[dim_pos] = idx

    return px.indexing.index(self, *indices)
404+
405+
def _head_tail_or_thin(
406+
self,
407+
indexers: dict[str, Any] | int | None,
408+
indexers_kwargs: dict[str, Any],
409+
*,
410+
kind: Literal["head", "tail", "thin"],
411+
):
412+
if indexers_kwargs:
413+
if indexers is not None:
414+
raise ValueError(
415+
"Cannot pass both indexers and indexers_kwargs to head"
416+
)
417+
indexers = indexers_kwargs
418+
419+
if indexers is None:
420+
if kind == "thin":
421+
raise TypeError(
422+
"thin() indexers must be either dict-like or a single integer"
423+
)
424+
else:
425+
# Default to 5 for head and tail
426+
indexers = {dim: 5 for dim in self.type.dims}
427+
428+
elif not isinstance(indexers, dict):
429+
indexers = {dim: indexers for dim in self.type.dims}
430+
431+
if kind == "head":
432+
indices = {dim: slice(None, value) for dim, value in indexers.items()}
433+
elif kind == "tail":
434+
sizes = self.sizes
435+
# Can't use slice(-value, None), in case value is zero
436+
indices = {
437+
dim: slice(sizes[dim] - value, None) for dim, value in indexers.items()
438+
}
439+
elif kind == "thin":
440+
indices = {dim: slice(None, None, value) for dim, value in indexers.items()}
441+
return self.isel(indices)
442+
443+
def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
    """Delegate to ``_head_tail_or_thin`` with ``kind="head"``."""
    return self._head_tail_or_thin(
        indexers=indexers,
        indexers_kwargs=indexers_kwargs,
        kind="head",
    )
445+
446+
def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
    """Delegate to ``_head_tail_or_thin`` with ``kind="tail"``."""
    return self._head_tail_or_thin(
        indexers=indexers,
        indexers_kwargs=indexers_kwargs,
        kind="tail",
    )
448+
449+
def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
    """Delegate to ``_head_tail_or_thin`` with ``kind="thin"``."""
    return self._head_tail_or_thin(
        indexers=indexers,
        indexers_kwargs=indexers_kwargs,
        kind="thin",
    )
344451

345452
# ndarray methods
346453
# https://docs.xarray.dev/en/latest/api.html#id7

tests/xtensor/test_indexing.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import numpy as np
2+
import pytest
3+
from xarray import DataArray
4+
5+
from pytensor.xtensor import xtensor
6+
from tests.xtensor.util import xr_assert_allclose, xr_function
7+
8+
9+
@pytest.mark.parametrize(
    "indices",
    [
        (0,),
        (slice(1, None),),
        (slice(None, -1),),
        (slice(None, None, -1),),
        (0, slice(None), -1, slice(1, None)),
        (..., 0, -1),
        (0, ..., -1),
        (0, -1, ...),
    ],
)
@pytest.mark.parametrize("labeled", (False, True), ids=["unlabeled", "labeled"])
def test_basic_indexing(labeled, indices):
    """Compare basic xtensor indexing against xarray on the same data.

    In the labeled variant the positional indices are assigned to a random
    permutation of the dims and passed as a dict.
    """
    if labeled and ... in indices:
        pytest.skip("Ellipsis not supported with labeled indexing")

    dims = ("a", "b", "c", "d")
    x = xtensor(dims=dims, shape=(2, 3, 5, 7))

    if labeled:
        shuffled_dims = tuple(np.random.permutation(dims))
        indices = dict(zip(shuffled_dims, indices, strict=False))
    out = x[indices]

    fn = xr_function([x], out)
    static_shape = x.type.shape
    x_test_values = np.arange(np.prod(static_shape), dtype=x.type.dtype).reshape(
        static_shape
    )
    x_test = DataArray(x_test_values, dims=x.type.dims)
    xr_assert_allclose(fn(x_test), x_test[indices])

0 commit comments

Comments
 (0)