Commit 5a7b23c
Implement xarray-like Concat
1 parent 3c53271 commit 5a7b23c
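For context, the commit adds an xarray-style concat for labeled tensors. A minimal usage sketch built only from the APIs this commit touches (the variable names here are illustrative, not from the commit):

from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import xtensor

# Two labeled tensors sharing dims "a" and "b" (names are illustrative)
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("a", "b"), shape=(4, 3))

# Concatenate along the existing dim "a": the result has a=6, b=3
z = concat([x, y], dim="a")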

File tree

4 files changed: +163 -4 lines changed

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
     linalg,
     special,
 )
+from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (
     XTensorType,
     as_xtensor,
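With this re-export in place, the helper should also be importable from the package root (a one-line sketch, assuming nothing else shadows the name):

from pytensor.xtensor import concat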

pytensor/xtensor/rewriting/shape.py

Lines changed: 45 additions & 2 deletions
@@ -1,8 +1,8 @@
 from pytensor.graph import node_rewriter
-from pytensor.tensor import moveaxis
+from pytensor.tensor import broadcast_to, join, moveaxis
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.basic import register_xcanonicalize
-from pytensor.xtensor.shape import Stack
+from pytensor.xtensor.shape import Concat, Stack


 @register_xcanonicalize
@@ -27,3 +27,46 @@ def lower_stack(fgraph, node):

     new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
     return [new_out]
+
+
+@register_xcanonicalize("shape_unsafe")
+@node_rewriter(tracks=[Concat])
+def lower_concat(fgraph, node):
+    out_dims = node.outputs[0].type.dims
+    concat_dim = node.op.dim
+    concat_axis = out_dims.index(concat_dim)
+
+    # Convert input XTensors to Tensors and align batch dimensions
+    tensor_inputs = []
+    for inp in node.inputs:
+        inp_dims = inp.type.dims
+        order = [
+            inp_dims.index(out_dim) if out_dim in inp_dims else "x"
+            for out_dim in out_dims
+        ]
+        tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
+        tensor_inputs.append(tensor_inp)
+
+    # Broadcast non-concatenated dimensions of each input
+    non_concat_shape = [None] * len(out_dims)
+    for tensor_inp in tensor_inputs:
+        # TODO: This assumes the graph is correct and every non-concat dimension
+        # matches in shape at runtime. The rewrite is registered as "shape_unsafe"
+        # to simplify the logic / returned graph.
+        for i, (bcast, sh) in enumerate(
+            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
+        ):
+            if bcast or i == concat_axis or non_concat_shape[i] is not None:
+                continue
+            non_concat_shape[i] = sh
+
+    assert non_concat_shape.count(None) == 1
+
+    bcast_tensor_inputs = []
+    for tensor_inp in tensor_inputs:
+        # We modify non_concat_shape[concat_axis] in place, as we don't need the list anywhere else
+        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
+        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
+
+    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
+    new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
+    return [new_out]
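The key alignment step above is the dimshuffle order: dims the input has are referenced by position, and dims it lacks become "x", i.e. a new broadcastable axis. A standalone sketch of just that computation (the example dims are made up):

# Illustration of the `order` computation used by lower_concat
out_dims = ("a", "b", "c")   # output dims, concat dim included
inp_dims = ("c", "a")        # one input's dims, missing "b"

order = [
    inp_dims.index(out_dim) if out_dim in inp_dims else "x"
    for out_dim in out_dims
]
print(order)  # [1, 'x', 0]: move "a" first, insert broadcastable "b", move "c" last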

pytensor/xtensor/shape.py

Lines changed: 54 additions & 0 deletions
@@ -1,6 +1,8 @@
 from collections.abc import Sequence

+from pytensor import Variable
 from pytensor.graph import Apply
+from pytensor.scalar import upcast
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import as_xtensor, xtensor

@@ -69,3 +71,55 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
     )
     y = Stack(new_dim_name, tuple(stacked_dims))(y)
     return y
+
+
+class Concat(XOp):
+    __props__ = ("dim",)
+
+    def __init__(self, dim: str):
+        self.dim = dim
+        super().__init__()
+
+    def make_node(self, *inputs: Variable) -> Apply:
+        inputs = [as_xtensor(inp) for inp in inputs]
+        concat_dim = self.dim
+
+        dims_and_shape: dict[str, int | None] = {}
+        for inp in inputs:
+            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
+                if dim not in dims_and_shape:
+                    dims_and_shape[dim] = dim_length
+                else:
+                    if dim == concat_dim:
+                        if dim_length is None:
+                            dims_and_shape[dim] = None
+                        elif dims_and_shape[dim] is not None:
+                            dims_and_shape[dim] += dim_length
+                    elif dim_length is not None:
+                        # Check for conflicting non-concatenated shapes
+                        if (dims_and_shape[dim] is not None) and (
+                            dims_and_shape[dim] != dim_length
+                        ):
+                            raise ValueError(
+                                f"Non-concatenated dimension {dim} has conflicting shapes"
+                            )
+                        # Keep the non-None shape
+                        dims_and_shape[dim] = dim_length
+
+        if concat_dim not in dims_and_shape:
+            # It's a new dim that should be located at the start
+            dims_and_shape = {concat_dim: len(inputs)} | dims_and_shape
+        elif dims_and_shape[concat_dim] is not None:
+            # We need to add +1 for every input that doesn't have this dimension
+            for inp in inputs:
+                if concat_dim not in inp.type.dims:
+                    dims_and_shape[concat_dim] += 1
+
+        dims, shape = zip(*dims_and_shape.items())
+        dtype = upcast(*[x.type.dtype for x in inputs])
+        output = xtensor(dtype=dtype, dims=dims, shape=shape)
+        return Apply(self, inputs, [output])
+
+
+def concat(xtensors, dim: str):
+    return Concat(dim=dim)(*xtensors)
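To make the static shape rules concrete: along the concat dim, known lengths add up, and an input that lacks the dim contributes 1. A hedged walkthrough (the printed values are reasoned from make_node above, not a documented guarantee):

from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import xtensor

x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b",), shape=(3,))  # lacks "a", so it contributes +1 along it

out = concat([x, y], dim="a")
print(out.type.dims)   # ("a", "b")
print(out.type.shape)  # (3, 3): 2 from x, plus 1 for y, which lacks "a"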

tests/xtensor/test_shape.py

Lines changed: 63 additions & 2 deletions
@@ -8,10 +8,11 @@

 import numpy as np
 from xarray import DataArray
+from xarray import concat as xr_concat

-from pytensor.xtensor.shape import stack
+from pytensor.xtensor.shape import concat, stack
 from pytensor.xtensor.type import xtensor
-from tests.xtensor.util import xr_assert_allclose, xr_function
+from tests.xtensor.util import xr_assert_allclose, xr_function, xr_random_like


 def powerset(iterable, min_group_size=0):
@@ -102,3 +103,63 @@ def test_multiple_stacks():
     res = fn(x_test)
     expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
     xr_assert_allclose(res[0], expected_res)
+
+
+@pytest.mark.parametrize("dim", ("a", "b", "new"))
+def test_concat(dim):
+    rng = np.random.default_rng(sum(map(ord, dim)))
+
+    x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3))
+    x2 = xtensor("x2", dims=("b", "a"), shape=(3, 2))
+
+    x3_shape0 = 4 if dim == "a" else 2
+    x3_shape1 = 5 if dim == "b" else 3
+    x3 = xtensor("x3", dims=("a", "b"), shape=(x3_shape0, x3_shape1))
+
+    out = concat([x1, x2, x3], dim=dim)
+
+    fn = xr_function([x1, x2, x3], out)
+    x1_test = xr_random_like(x1, rng)
+    x2_test = xr_random_like(x2, rng)
+    x3_test = xr_random_like(x3, rng)
+
+    res = fn(x1_test, x2_test, x3_test)
+    expected_res = xr_concat([x1_test, x2_test, x3_test], dim=dim)
+    xr_assert_allclose(res, expected_res)
+
+
+@pytest.mark.parametrize("dim", ("a", "b", "c", "d", "new"))
+def test_concat_with_broadcast(dim):
+    rng = np.random.default_rng(sum(map(ord, dim)) + 1)
+
+    x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3))
+    x2 = xtensor("x2", dims=("b", "c"), shape=(3, 5))
+    x3 = xtensor("x3", dims=("c", "d"), shape=(5, 7))
+    x4 = xtensor("x4", dims=(), shape=())
+
+    out = concat([x1, x2, x3, x4], dim=dim)
+
+    fn = xr_function([x1, x2, x3, x4], out)
+
+    x1_test = xr_random_like(x1, rng)
+    x2_test = xr_random_like(x2, rng)
+    x3_test = xr_random_like(x3, rng)
+    x4_test = xr_random_like(x4, rng)
+    res = fn(x1_test, x2_test, x3_test, x4_test)
+    expected_res = xr_concat([x1_test, x2_test, x3_test, x4_test], dim=dim)
+    xr_assert_allclose(res, expected_res)
+
+
+def test_concat_scalar():
+    x1 = xtensor("x1", dims=(), shape=())
+    x2 = xtensor("x2", dims=(), shape=())
+
+    out = concat([x1, x2], dim="new_dim")
+
+    fn = xr_function([x1, x2], out)
+
+    x1_test = xr_random_like(x1)
+    x2_test = xr_random_like(x2)
+    res = fn(x1_test, x2_test)
+    expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
+    xr_assert_allclose(res, expected_res)
