
Commit 3f3fd55

WIP Basic labeled tensor functionality
TODO: Split Stack from commit
1 parent 13dc0d4 commit 3f3fd55

File tree

11 files changed: +761 -0 lines changed

pytensor/xtensor/__init__.py

Lines changed: 12 additions & 0 deletions

import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor.type import (
    XTensorType,
    as_xtensor,
    xtensor,
    xtensor_constant,
)


warnings.warn("xtensor module is experimental and full of bugs")
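For orientation, a minimal usage sketch. The `xtensor` keyword signature (dtype/dims/shape) is inferred from its call sites elsewhere in this commit, not documented here:

import warnings; warnings.filterwarnings("ignore")  # silence the experimental warning
from pytensor.xtensor import xtensor

# A matrix whose axes are addressed by name rather than position.
x = xtensor(dtype="float64", dims=("batch", "channel"), shape=(None, 3))
print(x.type.dims)   # ("batch", "channel")
print(x.type.shape)  # (None, 3)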

pytensor/xtensor/basic.py

Lines changed: 84 additions & 0 deletions

from pytensor.graph import Apply, Op
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


class XOp(Op):
    """A base class for XOps that shouldn't be materialized"""

    def perform(self, node, inputs, outputs):
        raise NotImplementedError(
            f"xtensor operation {self} must be lowered to equivalent tensor operations"
        )


class XViewOp(Op):
    # TODO: Make this a proper view Op with a C implementation
    view_map = {0: [0]}

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0]


class TensorFromXTensor(XViewOp):
    __props__ = ()

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, XTensorType):
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])


tensor_from_xtensor = TensorFromXTensor()


class XTensorFromTensor(XViewOp):
    __props__ = ("dims",)

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def make_node(self, x) -> Apply:
        if not isinstance(x.type, TensorType):
            raise TypeError(f"x must have a TensorType, got {type(x.type)}")
        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
        return Apply(self, [x], [output])


def xtensor_from_tensor(x, dims):
    return XTensorFromTensor(dims=dims)(x)


class Rename(XViewOp):
    __props__ = ("new_dims",)

    def __init__(self, new_dims: tuple[str, ...]):
        super().__init__()
        self.new_dims = new_dims

    def make_node(self, x):
        x = as_xtensor(x)
        output = x.type.clone(dims=self.new_dims)()
        return Apply(self, [x], [output])


def rename(x, name_dict: dict[str, str] | None = None, **names: str):
    if name_dict is not None:
        if names:
            raise ValueError("Cannot use both positional and keyword names in rename")
        names = name_dict

    x = as_xtensor(x)
    old_names = x.type.dims
    new_names = list(old_names)
    for old_name, new_name in names.items():
        try:
            new_names[old_names.index(old_name)] = new_name
        except ValueError:
            # tuple.index raises ValueError (not IndexError) for a missing name
            raise ValueError(
                f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
            )

    return Rename(tuple(new_names))(x)
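A short sketch of how these helpers compose. The names come from this file; the `xtensor` constructor keywords are assumed from their use in shape.py below:

from pytensor.xtensor.basic import rename, tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.type import xtensor

x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))

# Strip the labels to get a plain TensorVariable, then re-attach fresh ones.
t = tensor_from_xtensor(x)                         # TensorType with shape (2, 3)
y = xtensor_from_tensor(t, dims=("rows", "cols"))

# Rename one dimension; a name that is not present raises ValueError.
z = rename(y, rows="r")
print(z.type.dims)  # ("r", "cols")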
pytensor/xtensor/rewriting/__init__.py

Lines changed: 2 additions & 0 deletions

import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.shape

pytensor/xtensor/rewriting/basic.py

Lines changed: 49 additions & 0 deletions

from pytensor.graph import node_rewriter
from pytensor.xtensor.basic import (
    Rename,
    TensorFromXTensor,
    XTensorFromTensor,
)
from pytensor.xtensor.rewriting.utils import register_xcanonicalize


@register_xcanonicalize
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
    """TensorFromXTensor(XTensorFromTensor(x)) -> x"""
    [x] = node.inputs
    if x.owner and isinstance(x.owner.op, XTensorFromTensor):
        return [x.owner.inputs[0]]


@register_xcanonicalize
@node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
    """XTensorFromTensor(TensorFromXTensor(x)) -> x"""
    [x] = node.inputs
    if x.owner and isinstance(x.owner.op, TensorFromXTensor):
        return [x.owner.inputs[0]]


@register_xcanonicalize
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor_of_rename(fgraph, node):
    """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)"""
    [renamed_x] = node.inputs
    if renamed_x.owner and isinstance(renamed_x.owner.op, Rename):
        [x] = renamed_x.owner.inputs
        return node.op(x, return_list=True)


@register_xcanonicalize
@node_rewriter(tracks=[Rename])
def useless_rename(fgraph, node):
    """
    Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims)
    Rename(XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFromTensor(x, outer_dims)
    """
    [renamed_x] = node.inputs
    if renamed_x.owner and isinstance(renamed_x.owner.op, Rename | XTensorFromTensor):
        [x] = renamed_x.owner.inputs
        return node.op(x, return_list=True)
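A sketch of what the round-trip elimination buys us. Reaching the registered database through `rewrite_graph` with the "xtensor" tag is an assumption about how the `optdb` query resolves, not something this commit shows:

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor

t = pt.matrix("t")
roundtrip = tensor_from_xtensor(xtensor_from_tensor(t, dims=("a", "b")))

# useless_tensor_from_xtensor should collapse the pair back to `t`.
rewritten = rewrite_graph(roundtrip, include=("xtensor",))
print(rewritten is t)  # expected: True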

pytensor/xtensor/rewriting/shape.py

Lines changed: 29 additions & 0 deletions

from pytensor.graph import node_rewriter
from pytensor.tensor import moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_xcanonicalize
from pytensor.xtensor.shape import Stack


@register_xcanonicalize
@node_rewriter(tracks=[Stack])
def lower_stack(fgraph, node):
    [x] = node.inputs
    batch_ndim = x.type.ndim - len(node.op.stacked_dims)
    stacked_axes = [
        i for i, dim in enumerate(x.type.dims) if dim in node.op.stacked_dims
    ]
    end = tuple(range(-len(stacked_axes), 0))

    x_tensor = tensor_from_xtensor(x)
    x_tensor_transposed = moveaxis(x_tensor, source=stacked_axes, destination=end)
    if batch_ndim == (x.type.ndim - 1):
        # When we stack a single dimension, all we need is the transpose.
        # Note: If we have meaningful rewrites before lowering, consider canonicalizing this as a Transpose + Rename
        final_tensor = x_tensor_transposed
    else:
        final_shape = (*tuple(x_tensor_transposed.shape)[:batch_ndim], -1)
        final_tensor = x_tensor_transposed.reshape(final_shape)

    new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
    return [new_out]
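The lowering is ordinary moveaxis-plus-reshape data movement. A numpy sketch of the same computation, for a hypothetical input with dims ("a", "b", "c") stacked over ("a", "c"):

import numpy as np

x = np.zeros((2, 3, 4))   # dims ("a", "b", "c")
stacked_axes = [0, 2]     # axes of "a" and "c"
end = tuple(range(-len(stacked_axes), 0))  # (-2, -1)

x_t = np.moveaxis(x, source=stacked_axes, destination=end)  # shape (3, 2, 4)
out = x_t.reshape((*x_t.shape[:1], -1))  # batch dims kept, stacked dims -> -1: shape (3, 8)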

pytensor/xtensor/rewriting/utils.py

Lines changed: 33 additions & 0 deletions

from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase


optdb.register(
    "xcanonicalize",
    EquilibriumDB(ignore_newtrees=False),
    "fast_run",
    "fast_compile",
    "xtensor",
    position=0,
)


def register_xcanonicalize(
    node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
):
    if isinstance(node_rewriter, str):

        def register(inner_rewriter: RewriteDatabase | NodeRewriter):
            return register_xcanonicalize(
                inner_rewriter, node_rewriter, *tags, **kwargs
            )

        return register

    else:
        name = kwargs.pop("name", None) or node_rewriter.__name__
        optdb["xtensor"].register(
            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
        )
        return node_rewriter
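The helper works both as a bare decorator (as used throughout rewriting/basic.py) and as a decorator factory when first handed a string, which becomes an extra registration tag. A sketch with a hypothetical rewriter and tag:

from pytensor.graph import node_rewriter
from pytensor.xtensor.basic import Rename
from pytensor.xtensor.rewriting.utils import register_xcanonicalize

@register_xcanonicalize("my_extra_tag")  # hypothetical tag, for illustration only
@node_rewriter(tracks=[Rename])
def my_hypothetical_rewrite(fgraph, node):
    # A real rewriter returns a list of replacement outputs, or None to decline.
    return None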

pytensor/xtensor/shape.py

Lines changed: 71 additions & 0 deletions

from collections.abc import Sequence

from pytensor.graph import Apply
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor


class Stack(XOp):
    __props__ = ("new_dim_name", "stacked_dims")

    def __init__(self, new_dim_name: str, stacked_dims: tuple[str, ...]):
        super().__init__()
        if new_dim_name in stacked_dims:
            raise ValueError(
                f"Stacking dim {new_dim_name} must not be in {stacked_dims}"
            )
        if not stacked_dims:
            raise ValueError(f"Stacking dims must not be empty: got {stacked_dims}")
        self.new_dim_name = new_dim_name
        self.stacked_dims = stacked_dims

    def make_node(self, x):
        x = as_xtensor(x)
        if not (set(self.stacked_dims) <= set(x.type.dims)):
            raise ValueError(
                f"Stacking dims {self.stacked_dims} must be a subset of {x.type.dims}"
            )
        if self.new_dim_name in x.type.dims:
            raise ValueError(
                f"Stacking dim {self.new_dim_name} must not be in {x.type.dims}"
            )
        if len(self.stacked_dims) == x.type.ndim:
            batch_dims, batch_shape = (), ()
        else:
            batch_dims, batch_shape = zip(
                *(
                    (dim, shape)
                    for dim, shape in zip(x.type.dims, x.type.shape)
                    if dim not in self.stacked_dims
                )
            )
        # Static size of the new dim: product of the stacked sizes, or None
        # as soon as any of them is statically unknown.
        stack_shape = 1
        for dim, shape in zip(x.type.dims, x.type.shape):
            if dim in self.stacked_dims:
                if shape is None:
                    stack_shape = None
                    break
                else:
                    stack_shape *= shape
        output = xtensor(
            dtype=x.type.dtype,
            shape=(*batch_shape, stack_shape),
            dims=(*batch_dims, self.new_dim_name),
        )
        return Apply(self, [x], [output])


def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
    if dim is not None:
        if dims:
            raise ValueError("Cannot use both positional dim and keyword dims in stack")
        dims = dim

    y = x
    for new_dim_name, stacked_dims in dims.items():
        if isinstance(stacked_dims, str):
            raise TypeError(
                f"Stacking dims must be a sequence of strings, got a single string: {stacked_dims}"
            )
        y = Stack(new_dim_name, tuple(stacked_dims))(y)
    return y
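A usage sketch of `stack`, which follows xarray's convention of mapping each new dimension name to the dimensions it consumes (the `xtensor` constructor keywords are again inferred from this commit):

from pytensor.xtensor.shape import stack
from pytensor.xtensor.type import xtensor

x = xtensor(dtype="float64", dims=("a", "b", "c"), shape=(2, 3, 4))

# Stack "a" and "c" into a new "ac" dimension; "b" is left as a batch dim.
y = stack(x, ac=("a", "c"))
print(y.type.dims)   # ("b", "ac")
print(y.type.shape)  # (3, 8)

# Equivalent dict form: stack(x, dim={"ac": ("a", "c")})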
