Skip to content

Commit c3da710

Browse files
committed
Implement Elemwise and Blockwise operations for XTensorVariables
1 parent 2054ad8 commit c3da710

File tree

9 files changed

+617
-4
lines changed

9 files changed

+617
-4
lines changed

pytensor/xtensor/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import warnings
22

33
import pytensor.xtensor.rewriting
4+
from pytensor.xtensor import (
5+
linalg,
6+
)
47
from pytensor.xtensor.type import (
58
XTensorType,
69
as_xtensor,

pytensor/xtensor/linalg.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from collections.abc import Sequence
2+
from typing import Literal
3+
4+
from pytensor.tensor.slinalg import Cholesky, Solve
5+
from pytensor.xtensor.type import as_xtensor
6+
from pytensor.xtensor.vectorization import XBlockwise
7+
8+
9+
def cholesky(
10+
x,
11+
lower: bool = True,
12+
*,
13+
check_finite: bool = False,
14+
overwrite_a: bool = False,
15+
on_error: Literal["raise", "nan"] = "raise",
16+
dims: Sequence[str],
17+
):
18+
if len(dims) != 2:
19+
raise ValueError(f"Cholesky needs two dims, got {len(dims)}")
20+
21+
core_op = Cholesky(
22+
lower=lower,
23+
check_finite=check_finite,
24+
overwrite_a=overwrite_a,
25+
on_error=on_error,
26+
)
27+
core_dims = (
28+
((dims[0], dims[1]),),
29+
((dims[0], dims[1]),),
30+
)
31+
x_op = XBlockwise(core_op, signature=core_op.gufunc_signature, core_dims=core_dims)
32+
return x_op(x)
33+
34+
35+
def solve(
36+
a,
37+
b,
38+
dims: Sequence[str],
39+
assume_a="gen",
40+
lower: bool = False,
41+
check_finite: bool = False,
42+
):
43+
a, b = as_xtensor(a), as_xtensor(b)
44+
if len(dims) == 2:
45+
b_ndim = 1
46+
[m1_dim] = [dim for dim in dims if dim not in b.type.dims]
47+
m2_dim = dims[0] if dims[0] != m1_dim else dims[1]
48+
input_core_dims = ((m1_dim, m2_dim), (m2_dim,))
49+
output_core_dims = ((m2_dim,),)
50+
elif len(dims) == 3:
51+
b_ndim = 2
52+
[n_dim] = [dim for dim in dims if dim not in a.type.dims]
53+
[m1_dim, m2_dim] = [dim for dim in dims if dim != n_dim]
54+
input_core_dims = ((m1_dim, m2_dim), (m2_dim, n_dim))
55+
output_core_dims = (
56+
(
57+
m2_dim,
58+
n_dim,
59+
),
60+
)
61+
else:
62+
raise ValueError("Solve dims must have length 2 or 3")
63+
64+
core_op = Solve(
65+
b_ndim=b_ndim, assume_a=assume_a, lower=lower, check_finite=check_finite
66+
)
67+
x_op = XBlockwise(
68+
core_op,
69+
signature=core_op.gufunc_signature,
70+
core_dims=(input_core_dims, output_core_dims),
71+
)
72+
return x_op(a, b)

pytensor/xtensor/math.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import inspect
2+
import sys
3+
4+
import pytensor.scalar as ps
5+
from pytensor.scalar import ScalarOp
6+
from pytensor.xtensor.vectorization import XElemwise
7+
8+
9+
this_module = sys.modules[__name__]
10+
11+
12+
def get_all_scalar_ops():
13+
"""
14+
Find all scalar operations in the pytensor.scalar module that can be wrapped with XElemwise.
15+
16+
Returns:
17+
dict: A dictionary mapping operation names to XElemwise instances
18+
"""
19+
result = {}
20+
21+
# Get all module members
22+
for name, obj in inspect.getmembers(ps):
23+
# Check if the object is a scalar op (has make_node method and is not an abstract class)
24+
if isinstance(obj, ScalarOp):
25+
result[name] = XElemwise(obj)
26+
27+
return result
28+
29+
30+
for name, op in get_all_scalar_ops().items():
31+
setattr(this_module, name, op)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
import pytensor.xtensor.rewriting.basic
22
import pytensor.xtensor.rewriting.shape
3+
import pytensor.xtensor.rewriting.vectorization
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from pytensor.graph import node_rewriter
2+
from pytensor.tensor.blockwise import Blockwise
3+
from pytensor.tensor.elemwise import Elemwise
4+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
5+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
6+
from pytensor.xtensor.vectorization import XBlockwise, XElemwise
7+
8+
9+
@register_xcanonicalize
10+
@node_rewriter(tracks=[XElemwise])
11+
def lower_elemwise(fgraph, node):
12+
out_dims = node.outputs[0].type.dims
13+
14+
# Convert input XTensors to Tensors and align batch dimensions
15+
tensor_inputs = []
16+
for inp in node.inputs:
17+
inp_dims = inp.type.dims
18+
order = [
19+
inp_dims.index(out_dim) if out_dim in inp_dims else "x"
20+
for out_dim in out_dims
21+
]
22+
tensor_inp = tensor_from_xtensor(inp).dimshuffle(order)
23+
tensor_inputs.append(tensor_inp)
24+
25+
tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
26+
*tensor_inputs, return_list=True
27+
)
28+
29+
# Convert output Tensors to XTensors
30+
new_outs = [
31+
xtensor_from_tensor(tensor_out, dims=out_dims) for tensor_out in tensor_outs
32+
]
33+
return new_outs
34+
35+
36+
@register_xcanonicalize
37+
@node_rewriter(tracks=[XBlockwise])
38+
def lower_blockwise(fgraph, node):
39+
op: XBlockwise = node.op
40+
batch_ndim = node.outputs[0].type.ndim - len(op.outputs_sig[0])
41+
batch_dims = node.outputs[0].type.dims[:batch_ndim]
42+
43+
# Convert input Tensors to XTensors, align batch dimensions and place core dimension at the end
44+
tensor_inputs = []
45+
for inp, core_dims in zip(node.inputs, op.core_dims[0]):
46+
inp_dims = inp.type.dims
47+
# Align the batch dims of the input, and place the core dims on the right
48+
batch_order = [
49+
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
50+
for batch_dim in batch_dims
51+
]
52+
core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
53+
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
54+
tensor_inputs.append(tensor_inp)
55+
56+
tensor_op = Blockwise(core_op=node.op.core_op, signature=op.signature)
57+
tensor_outs = tensor_op(*tensor_inputs, return_list=True)
58+
59+
# Convert output Tensors to XTensors
60+
new_outs = [
61+
xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
62+
for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
63+
]
64+
return new_outs

pytensor/xtensor/type.py

Lines changed: 146 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,109 @@ def __complex__(self):
146146
"Call `.astype(complex)` for the symbolic equivalent."
147147
)
148148

149+
# Python valid overloads
150+
def __abs__(self):
151+
return px.math.abs(self)
152+
153+
def __neg__(self):
154+
return px.math.neg(self)
155+
156+
def __lt__(self, other):
157+
return px.math.lt(self, other)
158+
159+
def __le__(self, other):
160+
return px.math.le(self, other)
161+
162+
def __gt__(self, other):
163+
return px.math.gt(self, other)
164+
165+
def __ge__(self, other):
166+
return px.math.ge(self, other)
167+
168+
def __invert__(self):
169+
return px.math.invert(self)
170+
171+
def __and__(self, other):
172+
return px.math.and_(self, other)
173+
174+
def __or__(self, other):
175+
return px.math.or_(self, other)
176+
177+
def __xor__(self, other):
178+
return px.math.xor(self, other)
179+
180+
def __rand__(self, other):
181+
return px.math.and_(other, self)
182+
183+
def __ror__(self, other):
184+
return px.math.or_(other, self)
185+
186+
def __rxor__(self, other):
187+
return px.math.xor(other, self)
188+
189+
def __add__(self, other):
190+
return px.math.add(self, other)
191+
192+
def __sub__(self, other):
193+
return px.math.sub(self, other)
194+
195+
def __mul__(self, other):
196+
return px.math.mul(self, other)
197+
198+
def __div__(self, other):
199+
return px.math.div(self, other)
200+
201+
def __pow__(self, other):
202+
return px.math.pow(self, other)
203+
204+
def __mod__(self, other):
205+
return px.math.mod(self, other)
206+
207+
def __divmod__(self, other):
208+
return px.math.divmod(self, other)
209+
210+
def __truediv__(self, other):
211+
return px.math.true_div(self, other)
212+
213+
def __floordiv__(self, other):
214+
return px.math.floor_div(self, other)
215+
216+
def __rtruediv__(self, other):
217+
return px.math.true_div(other, self)
218+
219+
def __rfloordiv__(self, other):
220+
return px.math.floor_div(other, self)
221+
222+
def __radd__(self, other):
223+
return px.math.add(other, self)
224+
225+
def __rsub__(self, other):
226+
return px.math.sub(other, self)
227+
228+
def __rmul__(self, other):
229+
return px.math.mul(other, self)
230+
231+
def __rdiv__(self, other):
232+
return px.math.div_proxy(other, self)
233+
234+
def __rmod__(self, other):
235+
return px.math.mod(other, self)
236+
237+
def __rdivmod__(self, other):
238+
return px.math.divmod(other, self)
239+
240+
def __rpow__(self, other):
241+
return px.math.pow(other, self)
242+
243+
def __ceil__(self):
244+
return px.math.ceil(self)
245+
246+
def __floor__(self):
247+
return px.math.floor(self)
248+
249+
def __trunc__(self):
250+
return px.math.trunc(self)
251+
149252
# DataArray-like attributes
150253
# https://docs.xarray.dev/en/latest/api.html#id1
151254
@property
@@ -192,14 +295,33 @@ def dtype(self):
192295

193296
# DataArray contents
194297
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
195-
def rename(self, new_name_or_name_dict, **names):
298+
def rename(self, new_name_or_name_dict=None, **names):
196299
if isinstance(new_name_or_name_dict, str):
197-
# TODO: Should we make a symbolic copy?
198-
self.name = new_name_or_name_dict
300+
new_name = new_name_or_name_dict
199301
name_dict = None
200302
else:
303+
new_name = None
201304
name_dict = new_name_or_name_dict
202-
return px.basic.rename(name_dict, **names)
305+
new_out = px.basic.rename(self, name_dict, **names)
306+
new_out.name = new_name
307+
return new_out
308+
309+
# def swap_dims(self, *args, **kwargs):
310+
# ...
311+
#
312+
# def expand_dims(self, *args, **kwargs):
313+
# ...
314+
#
315+
# def squeeze(self):
316+
# ...
317+
318+
def copy(self, name: str | None = None):
319+
out = px.math.identity(self)
320+
out.name = name
321+
return out
322+
323+
def astype(self, dtype):
324+
return px.math.cast(self, dtype)
203325

204326
def item(self):
205327
raise NotImplementedError("item not implemented for XTensorVariable")
@@ -219,6 +341,26 @@ def sel(self, *args, **kwargs):
219341
def __getitem__(self, idx):
220342
raise NotImplementedError("Indexing not yet implemnented")
221343

344+
# ndarray methods
345+
# https://docs.xarray.dev/en/latest/api.html#id7
346+
def clip(self, min, max):
347+
return px.math.clip(self, min, max)
348+
349+
def conj(self):
350+
return px.math.conj(self)
351+
352+
@property
353+
def imag(self):
354+
return px.math.imag(self)
355+
356+
@property
357+
def real(self):
358+
return px.math.real(self)
359+
360+
# @property
361+
# def T(self):
362+
# ...
363+
222364

223365
class XTensorConstantSignature(tuple):
224366
def __eq__(self, other):

0 commit comments

Comments
 (0)