Skip to content

Commit 9cc9f9b

Browse files
committed
Implement Elemwise and Blockwise operations for XTensorVariables
1 parent 3f3fd55 commit 9cc9f9b

File tree

9 files changed

+605
-4
lines changed

9 files changed

+605
-4
lines changed

pytensor/xtensor/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import warnings
22

33
import pytensor.xtensor.rewriting
4+
from pytensor.xtensor import (
5+
linalg,
6+
)
47
from pytensor.xtensor.type import (
58
XTensorType,
69
as_xtensor,

pytensor/xtensor/linalg.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from collections.abc import Sequence
2+
from typing import Literal
3+
4+
from pytensor.tensor.slinalg import Cholesky, Solve
5+
from pytensor.xtensor.type import as_xtensor
6+
from pytensor.xtensor.vectorization import XBlockwise
7+
8+
9+
def cholesky(
10+
x,
11+
lower: bool = True,
12+
*,
13+
check_finite: bool = False,
14+
overwrite_a: bool = False,
15+
on_error: Literal["raise", "nan"] = "raise",
16+
dims: Sequence[str],
17+
):
18+
if len(dims) != 2:
19+
raise ValueError(f"Cholesky needs two dims, got {len(dims)}")
20+
21+
core_op = Cholesky(
22+
lower=lower,
23+
check_finite=check_finite,
24+
overwrite_a=overwrite_a,
25+
on_error=on_error,
26+
)
27+
core_dims = (
28+
((dims[0], dims[1]),),
29+
((dims[0], dims[1]),),
30+
)
31+
x_op = XBlockwise(core_op, signature=core_op.gufunc_signature, core_dims=core_dims)
32+
return x_op(x)
33+
34+
35+
def solve(
36+
a,
37+
b,
38+
dims: Sequence[str],
39+
assume_a="gen",
40+
lower: bool = False,
41+
check_finite: bool = False,
42+
):
43+
a, b = as_xtensor(a), as_xtensor(b)
44+
if len(dims) == 2:
45+
b_ndim = 1
46+
[m1_dim] = [dim for dim in dims if dim not in b.type.dims]
47+
m2_dim = dims[0] if dims[0] != m1_dim else dims[1]
48+
input_core_dims = ((m1_dim, m2_dim), (m2_dim,))
49+
output_core_dims = ((m2_dim,),)
50+
elif len(dims) == 3:
51+
b_ndim = 2
52+
[n_dim] = [dim for dim in dims if dim not in a.type.dims]
53+
[m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim]
54+
input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim))
55+
output_core_dims = (
56+
(
57+
m2_dim,
58+
n_dim,
59+
),
60+
)
61+
else:
62+
raise ValueError("Solve dims must have length 2 or 3")
63+
64+
core_op = Solve(
65+
b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite
66+
)
67+
x_op = XBlockwise(
68+
core_op,
69+
signature=core_op.gufunc_signature,
70+
core_dims=(input_core_dims, output_core_dims),
71+
)
72+
return x_op(a, b)

pytensor/xtensor/math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import inspect
2+
import sys
3+
4+
import pytensor.scalar as ps
5+
from pytensor.scalar import ScalarOp
6+
from pytensor.xtensor.vectorization import XElemwise
7+
8+
9+
this_module = sys.modules[__name__]
10+
11+
12+
def get_all_scalar_ops():
13+
"""
14+
Find all scalar operations in the pytensor.scalar module that can be wrapped with XElemwise.
15+
16+
Returns:
17+
dict: A dictionary mapping operation names to XElemwise instances
18+
"""
19+
result = {}
20+
21+
# Get all module members
22+
for name, obj in inspect.getmembers(ps):
23+
# Check if the object is a scalar op (has make_node method and is not an abstract class)
24+
if isinstance(obj, ScalarOp):
25+
result[name] = XElemwise(obj)
26+
27+
return result
28+
29+
30+
for name, op in get_all_scalar_ops().items():
31+
setattr(this_module, name, op)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
import pytensor.xtensor.rewriting.basic
22
import pytensor.xtensor.rewriting.shape
3+
import pytensor.xtensor.rewriting.vectorization
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor.blockwise import Blockwise
3+
from pytensor.tensor.elemwise import Elemwise
4+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
5+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
6+
from pytensor.xtensor.vectorization import XBlockwise, XElemwise
7+
8+
9+
@register_xcanonicalize
10+
@node_rewriter(tracks=[XElemwise])
11+
def lower_elemwise(fgraph, node):
12+
out_dims = node.outputs[0].type.dims
13+
14+
# Convert input XTensors to Tensors and align batch dimensions
15+
tensor_inputs = []
16+
for inp in node.inputs:
17+
inp_dims = inp.type.dims
18+
order = [
19+
inp_dims.index(out_dim) if out_dim in inp_dims else "x"
20+
for out_dim in out_dims
21+
]
22+
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
23+
tensor_inputs.append(tensor_inp)
24+
25+
tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
26+
*tensor_inputs, return_list=True
27+
)
28+
29+
# Convert output Tensors to XTensors
30+
new_outs = [
31+
xtensor_from_tensor(tensor_out, dims=out_dims) for tensor_out in tensor_outs
32+
]
33+
return new_outs
34+
35+
36+
@register_xcanonicalize
37+
@node_rewriter(tracks=[XBlockwise])
38+
def lower_blockwise(fgraph, node):
39+
op: XBlockwise = node.op
40+
batch_ndim = node.outputs[0].type.ndim - len(op.outputs_sig[0])
41+
batch_dims = node.outputs[0].type.dims[:batch_ndim]
42+
43+
# Convert input Tensors to XTensors, align batch dimensions and place core dimension at the end
44+
tensor_inputs = []
45+
for inp, core_dims in zip(node.inputs, op.core_dims[0]):
46+
inp_dims = inp.type.dims
47+
# Align the batch dims of the input, and place the core dims on the right
48+
batch_order = [
49+
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
50+
for batch_dim in batch_dims
51+
]
52+
core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
53+
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
54+
tensor_inputs.append(tensor_inp)
55+
56+
tensor_op = Blockwise(core_op=node.op.core_op, signature=op.signature)
57+
tensor_outs = tensor_op(*tensor_inputs, return_list=True)
58+
59+
# Convert output Tensors to XTensors
60+
new_outs = [
61+
xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
62+
for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
63+
]
64+
return new_outs

pytensor/xtensor/type.py

Lines changed: 146 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,109 @@ def __complex__(self):
147147
"Call `.astype(complex)` for the symbolic equivalent."
148148
)
149149

150+
# Python valid overloads
151+
def __abs__(self):
152+
return px.math.abs(self)
153+
154+
def __neg__(self):
155+
return px.math.neg(self)
156+
157+
def __lt__(self, other):
158+
return px.math.lt(self, other)
159+
160+
def __le__(self, other):
161+
return px.math.le(self, other)
162+
163+
def __gt__(self, other):
164+
return px.math.gt(self, other)
165+
166+
def __ge__(self, other):
167+
return px.math.ge(self, other)
168+
169+
def __invert__(self):
170+
return px.math.invert(self)
171+
172+
def __and__(self, other):
173+
return px.math.and_(self, other)
174+
175+
def __or__(self, other):
176+
return px.math.or_(self, other)
177+
178+
def __xor__(self, other):
179+
return px.math.xor(self, other)
180+
181+
def __rand__(self, other):
182+
return px.math.and_(other, self)
183+
184+
def __ror__(self, other):
185+
return px.math.or_(other, self)
186+
187+
def __rxor__(self, other):
188+
return px.math.xor(other, self)
189+
190+
def __add__(self, other):
191+
return px.math.add(self, other)
192+
193+
def __sub__(self, other):
194+
return px.math.sub(self, other)
195+
196+
def __mul__(self, other):
197+
return px.math.mul(self, other)
198+
199+
def __div__(self, other):
200+
return px.math.div(self, other)
201+
202+
def __pow__(self, other):
203+
return px.math.pow(self, other)
204+
205+
def __mod__(self, other):
206+
return px.math.mod(self, other)
207+
208+
def __divmod__(self, other):
209+
return px.math.divmod(self, other)
210+
211+
def __truediv__(self, other):
212+
return px.math.true_div(self, other)
213+
214+
def __floordiv__(self, other):
215+
return px.math.floor_div(self, other)
216+
217+
def __rtruediv__(self, other):
218+
return px.math.true_div(other, self)
219+
220+
def __rfloordiv__(self, other):
221+
return px.math.floor_div(other, self)
222+
223+
def __radd__(self, other):
224+
return px.math.add(other, self)
225+
226+
def __rsub__(self, other):
227+
return px.math.sub(other, self)
228+
229+
def __rmul__(self, other):
230+
return px.math.mul(other, self)
231+
232+
def __rdiv__(self, other):
233+
return px.math.div_proxy(other, self)
234+
235+
def __rmod__(self, other):
236+
return px.math.mod(other, self)
237+
238+
def __rdivmod__(self, other):
239+
return px.math.divmod(other, self)
240+
241+
def __rpow__(self, other):
242+
return px.math.pow(other, self)
243+
244+
def __ceil__(self):
245+
return px.math.ceil(self)
246+
247+
def __floor__(self):
248+
return px.math.floor(self)
249+
250+
def __trunc__(self):
251+
return px.math.trunc(self)
252+
150253
# DataArray-like attributes
151254
# https://docs.xarray.dev/en/latest/api.html#id1
152255
@property
@@ -193,14 +296,33 @@ def dtype(self):
193296

194297
# DataArray contents
195298
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
196-
def rename(self, new_name_or_name_dict, **names):
299+
def rename(self, new_name_or_name_dict=None, **names):
197300
if isinstance(new_name_or_name_dict, str):
198-
# TODO: Should we make a symbolic copy?
199-
self.name = new_name_or_name_dict
301+
new_name = new_name_or_name_dict
200302
name_dict = None
201303
else:
304+
new_name = None
202305
name_dict = new_name_or_name_dict
203-
return px.basic.rename(name_dict, **names)
306+
new_out = px.basic.rename(self, name_dict, **names)
307+
new_out.name = new_name
308+
return new_out
309+
310+
# def swap_dims(self, *args, **kwargs):
311+
# ...
312+
#
313+
# def expand_dims(self, *args, **kwargs):
314+
# ...
315+
#
316+
# def squeeze(self):
317+
# ...
318+
319+
def copy(self, name: str | None = None):
320+
out = px.math.identity(self)
321+
out.name = name
322+
return out
323+
324+
def astype(self, dtype):
325+
return px.math.cast(self, dtype)
204326

205327
def item(self):
206328
raise NotImplementedError("item not implemented for XTensorVariable")
@@ -220,6 +342,26 @@ def sel(self, *args, **kwargs):
220342
def __getitem__(self, idx):
221343
raise NotImplementedError("Indexing not yet implemnented")
222344

345+
# ndarray methods
346+
# https://docs.xarray.dev/en/latest/api.html#id7
347+
def clip(self, min, max):
348+
return px.math.clip(self, min, max)
349+
350+
def conj(self):
351+
return px.math.conj(self)
352+
353+
@property
354+
def imag(self):
355+
return px.math.imag(self)
356+
357+
@property
358+
def real(self):
359+
return px.math.real(self)
360+
361+
# @property
362+
# def T(self):
363+
# ...
364+
223365

224366
class XTensorConstantSignature(tuple):
225367
def __eq__(self, other):

0 commit comments

Comments
 (0)