Skip to content

Commit 597e711

Browse files
committed
Implement reduction operations for XTensorVariables
1 parent 9cc9f9b commit 597e711

File tree

8 files changed

+261
-21
lines changed

8 files changed

+261
-21
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -473,24 +473,6 @@ def cumprod(x, axis=None):
473473
return CumOp(axis=axis, mode="mul")(x)
474474

475475

476-
class CumsumOp(Op):
477-
__props__ = ("axis",)
478-
479-
def __new__(typ, *args, **kwargs):
480-
obj = object.__new__(CumOp, *args, **kwargs)
481-
obj.mode = "add"
482-
return obj
483-
484-
485-
class CumprodOp(Op):
486-
__props__ = ("axis",)
487-
488-
def __new__(typ, *args, **kwargs):
489-
obj = object.__new__(CumOp, *args, **kwargs)
490-
obj.mode = "mul"
491-
return obj
492-
493-
494476
def diff(x, n=1, axis=-1):
495477
"""Calculate the `n`-th order discrete difference along the given `axis`.
496478

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytensor.xtensor.rewriting
44
from pytensor.xtensor import (
55
linalg,
6+
special,
67
)
78
from pytensor.xtensor.type import (
89
XTensorType,

pytensor/xtensor/reduction.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from collections.abc import Sequence
2+
from functools import partial
3+
from types import EllipsisType
4+
5+
import pytensor.scalar as ps
6+
from pytensor.graph.basic import Apply, Variable
7+
from pytensor.tensor.math import variadic_mul
8+
from pytensor.xtensor.basic import XOp
9+
from pytensor.xtensor.math import neq, sqrt
10+
from pytensor.xtensor.math import sqr as square
11+
from pytensor.xtensor.type import as_xtensor, xtensor
12+
13+
14+
REDUCE_DIM = str | Sequence[str] | EllipsisType | None
15+
16+
17+
class XReduce(XOp):
18+
__slots__ = ("binary_op", "dims")
19+
20+
def __init__(self, binary_op, dims: Sequence[str]):
21+
self.binary_op = binary_op
22+
# Order of reduce dims doens't change the behavior of the Op
23+
self.dims = tuple(sorted(dims))
24+
25+
def make_node(self, x: Variable) -> Apply:
26+
x = as_xtensor(x)
27+
x_dims = x.type.dims
28+
x_dims_set = set(x_dims)
29+
reduce_dims_set = set(self.dims)
30+
if x_dims_set == reduce_dims_set:
31+
out_dims, out_shape = [], []
32+
else:
33+
if not reduce_dims_set.issubset(x_dims_set):
34+
raise ValueError(
35+
f"Reduced dims {self.dims} not found in array dimensions {x_dims}."
36+
)
37+
out_dims, out_shape = zip(
38+
*[
39+
(d, s)
40+
for d, s in zip(x_dims, x.type.shape)
41+
if d not in reduce_dims_set
42+
]
43+
)
44+
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
45+
return Apply(self, [x], [output])
46+
47+
48+
def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]:
49+
if isinstance(dim, str):
50+
return (dim,)
51+
elif dim is None or dim is Ellipsis:
52+
x = as_xtensor(x)
53+
return x.type.dims
54+
return dim
55+
56+
57+
def reduce(x, dim: REDUCE_DIM = None, *, binary_op):
58+
dims = _process_user_dims(x, dim)
59+
return XReduce(binary_op=binary_op, dims=dims)(x)
60+
61+
62+
sum = partial(reduce, binary_op=ps.add)
63+
prod = partial(reduce, binary_op=ps.mul)
64+
max = partial(reduce, binary_op=ps.scalar_maximum)
65+
min = partial(reduce, binary_op=ps.scalar_minimum)
66+
67+
68+
def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):
69+
x = as_xtensor(x)
70+
if x.type.dtype != "bool":
71+
x = neq(x, 0)
72+
return reduce(x, dim=dim, binary_op=binary_op)
73+
74+
75+
all = partial(bool_reduce, binary_op=ps.and_)
76+
any = partial(bool_reduce, binary_op=ps.or_)
77+
78+
79+
def _infer_reduced_size(original_var, reduced_var):
80+
reduced_dims = reduced_var.dims
81+
return variadic_mul(
82+
*[size for dim, size in original_var.sizes if dim not in reduced_dims]
83+
)
84+
85+
86+
def mean(x, dim: REDUCE_DIM):
87+
x = as_xtensor(x)
88+
sum_x = sum(x, dim)
89+
n = _infer_reduced_size(x, sum_x)
90+
return sum_x / n
91+
92+
93+
def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
94+
x = as_xtensor(x)
95+
x_mean = mean(x, dim)
96+
n = _infer_reduced_size(x, x_mean)
97+
return square(x - x_mean) / (n - ddof)
98+
99+
100+
def std(x, dim: REDUCE_DIM, *, ddof: int = 0):
101+
return sqrt(var(x, dim, ddof=ddof))
102+
103+
104+
class XCumReduce(XOp):
105+
__props__ = ("binary_op", "dims")
106+
107+
def __init__(self, binary_op, dims: Sequence[str]):
108+
self.binary_op = binary_op
109+
self.dims = tuple(sorted(dims)) # Order doesn't matter
110+
111+
def make_node(self, x: Variable) -> Apply:
112+
x = as_xtensor(x)
113+
out = x.type()
114+
return Apply(self, [x], [out])
115+
116+
117+
def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
118+
dims = _process_user_dims(x, dim)
119+
return XCumReduce(dims=dims, binary_op=binary_op)(x)
120+
121+
122+
cumsum = partial(cumreduce, binary_op=ps.add)
123+
cumprod = partial(cumreduce, binary_op=ps.mul)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
import pytensor.xtensor.rewriting.basic
2+
import pytensor.xtensor.rewriting.reduction
23
import pytensor.xtensor.rewriting.shape
34
import pytensor.xtensor.rewriting.vectorization
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from functools import partial
2+
3+
import pytensor.scalar as ps
4+
from pytensor.graph.rewriting.basic import node_rewriter
5+
from pytensor.tensor.extra_ops import CumOp
6+
from pytensor.tensor.math import All, Any, CAReduce, Max, Min, Prod, Sum
7+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
8+
from pytensor.xtensor.reduction import XCumReduce, XReduce
9+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
10+
11+
12+
@register_xcanonicalize
13+
@node_rewriter(tracks=[XReduce])
14+
def lower_reduce(fgraph, node):
15+
[x] = node.inputs
16+
[out] = node.outputs
17+
x_dims = x.type.dims
18+
reduce_dims = node.op.dims
19+
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
20+
21+
if not reduce_axis:
22+
return [x]
23+
24+
match node.op.binary_op:
25+
case ps.add:
26+
tensor_op_class = Sum
27+
case ps.mul:
28+
tensor_op_class = Prod
29+
case ps.and_:
30+
tensor_op_class = All
31+
case ps.or_:
32+
tensor_op_class = Any
33+
case ps.scalar_maximum:
34+
tensor_op_class = Max
35+
case ps.scalar_minimum:
36+
tensor_op_class = Min
37+
case _:
38+
# Case without known/predefined Ops
39+
tensor_op_class = partial(CAReduce, scalar_op=node.op.binary_op)
40+
41+
x_tensor = tensor_from_xtensor(x)
42+
out_tensor = tensor_op_class(axis=reduce_axis)(x_tensor)
43+
new_out = xtensor_from_tensor(out_tensor, out.type.dims)
44+
return [new_out]
45+
46+
47+
@register_xcanonicalize
48+
@node_rewriter(tracks=[XCumReduce])
49+
def lower_cumreduce(fgraph, node):
50+
[x] = node.inputs
51+
x_dims = x.type.dims
52+
reduce_dims = node.op.dims
53+
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
54+
55+
if not reduce_axis:
56+
return [x]
57+
58+
match node.op.binary_op:
59+
case ps.add:
60+
tensor_op_class = partial(CumOp, mode="add")
61+
case ps.mul:
62+
tensor_op_class = partial(CumOp, mode="mul")
63+
case _:
64+
# We don't know how to convert an arbitrary binary cum/reduce Op
65+
return None
66+
67+
# Each dim corresponds to an application of Cumsum/Cumprod
68+
out_tensor = tensor_from_xtensor(x)
69+
for axis in reduce_axis:
70+
out_tensor = tensor_op_class(axis=axis)(out_tensor)
71+
out = xtensor_from_tensor(out_tensor, x.type.dims)
72+
return [out]

pytensor/xtensor/special.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pytensor.xtensor.math import exp
2+
from pytensor.xtensor.reduction import REDUCE_DIM
3+
4+
5+
def softmax(x, dim: REDUCE_DIM = None):
6+
exp_x = exp(x)
7+
return exp_x / exp_x.sum(dim=dim)

pytensor/xtensor/type.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,9 +358,40 @@ def imag(self):
358358
def real(self):
359359
return px.math.real(self)
360360

361-
# @property
362-
# def T(self):
363-
# ...
361+
# Aggregation
362+
# https://docs.xarray.dev/en/latest/api.html#id6
363+
def all(self, dim):
364+
return px.reduction.all(self, dim)
365+
366+
def any(self, dim):
367+
return px.reduction.any(self, dim)
368+
369+
def max(self, dim):
370+
return px.reduction.max(self, dim)
371+
372+
def min(self, dim):
373+
return px.reduction.min(self, dim)
374+
375+
def mean(self, dim):
376+
return px.reduction.mean(self, dim)
377+
378+
def prod(self, dim):
379+
return px.reduction.prod(self, dim)
380+
381+
def sum(self, dim):
382+
return px.reduction.sum(self, dim)
383+
384+
def std(self, dim):
385+
return px.reduction.std(self, dim)
386+
387+
def var(self, dim):
388+
return px.reduction.var(self, dim)
389+
390+
def cumsum(self, dim):
391+
return px.reduction.cumsum(self, dim)
392+
393+
def cumprod(self, dim):
394+
return px.reduction.cumprod(self, dim)
364395

365396

366397
class XTensorConstantSignature(tuple):

tests/xtensor/test_reduction.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
from pytensor.xtensor.type import xtensor
4+
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
5+
6+
7+
@pytest.mark.parametrize(
8+
"dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
9+
)
10+
@pytest.mark.parametrize(
11+
"method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:]
12+
)
13+
def test_reduction(method, dim):
14+
x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
15+
out = getattr(x, method)(dim=dim)
16+
17+
fn = xr_function([x], out)
18+
x_test = xr_arange_like(x)
19+
20+
xr_assert_allclose(
21+
fn(x_test),
22+
getattr(x_test, method)(dim=dim),
23+
)

0 commit comments

Comments
 (0)