Skip to content

Commit 3c53271

Browse files
committed
Implement reduction operations for XTensorVariables
1 parent c3da710 commit 3c53271

File tree

8 files changed

+262
-21
lines changed

8 files changed

+262
-21
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -473,24 +473,6 @@ def cumprod(x, axis=None):
473473
return CumOp(axis=axis, mode="mul")(x)
474474

475475

476-
class CumsumOp(Op):
477-
__props__ = ("axis",)
478-
479-
def __new__(typ, *args, **kwargs):
480-
obj = object.__new__(CumOp, *args, **kwargs)
481-
obj.mode = "add"
482-
return obj
483-
484-
485-
class CumprodOp(Op):
486-
__props__ = ("axis",)
487-
488-
def __new__(typ, *args, **kwargs):
489-
obj = object.__new__(CumOp, *args, **kwargs)
490-
obj.mode = "mul"
491-
return obj
492-
493-
494476
def diff(x, n=1, axis=-1):
495477
"""Calculate the `n`-th order discrete difference along the given `axis`.
496478

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytensor.xtensor.rewriting
44
from pytensor.xtensor import (
55
linalg,
6+
special,
67
)
78
from pytensor.xtensor.type import (
89
XTensorType,

pytensor/xtensor/reduction.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from collections.abc import Sequence
2+
from functools import partial
3+
from types import EllipsisType
4+
5+
import pytensor.scalar as ps
6+
from pytensor.graph.basic import Apply, Variable
7+
from pytensor.tensor.math import variadic_mul
8+
from pytensor.xtensor.basic import XOp
9+
from pytensor.xtensor.math import neq, sqrt
10+
from pytensor.xtensor.math import sqr as square
11+
from pytensor.xtensor.type import as_xtensor, xtensor
12+
13+
14+
REDUCE_DIM = str | Sequence[str] | EllipsisType | None
15+
16+
17+
class XReduce(XOp):
18+
__slots__ = ("binary_op", "dims")
19+
20+
def __init__(self, binary_op, dims: Sequence[str]):
21+
super().__init__()
22+
self.binary_op = binary_op
23+
# Order of reduce dims doens't change the behavior of the Op
24+
self.dims = tuple(sorted(dims))
25+
26+
def make_node(self, x: Variable) -> Apply:
27+
x = as_xtensor(x)
28+
x_dims = x.type.dims
29+
x_dims_set = set(x_dims)
30+
reduce_dims_set = set(self.dims)
31+
if x_dims_set == reduce_dims_set:
32+
out_dims, out_shape = [], []
33+
else:
34+
if not reduce_dims_set.issubset(x_dims_set):
35+
raise ValueError(
36+
f"Reduced dims {self.dims} not found in array dimensions {x_dims}."
37+
)
38+
out_dims, out_shape = zip(
39+
*[
40+
(d, s)
41+
for d, s in zip(x_dims, x.type.shape)
42+
if d not in reduce_dims_set
43+
]
44+
)
45+
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
46+
return Apply(self, [x], [output])
47+
48+
49+
def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]:
50+
if isinstance(dim, str):
51+
return (dim,)
52+
elif dim is None or dim is Ellipsis:
53+
x = as_xtensor(x)
54+
return x.type.dims
55+
return dim
56+
57+
58+
def reduce(x, dim: REDUCE_DIM = None, *, binary_op):
59+
dims = _process_user_dims(x, dim)
60+
return XReduce(binary_op=binary_op, dims=dims)(x)
61+
62+
63+
sum = partial(reduce, binary_op=ps.add)
64+
prod = partial(reduce, binary_op=ps.mul)
65+
max = partial(reduce, binary_op=ps.scalar_maximum)
66+
min = partial(reduce, binary_op=ps.scalar_minimum)
67+
68+
69+
def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):
70+
x = as_xtensor(x)
71+
if x.type.dtype != "bool":
72+
x = neq(x, 0)
73+
return reduce(x, dim=dim, binary_op=binary_op)
74+
75+
76+
all = partial(bool_reduce, binary_op=ps.and_)
77+
any = partial(bool_reduce, binary_op=ps.or_)
78+
79+
80+
def _infer_reduced_size(original_var, reduced_var):
81+
reduced_dims = reduced_var.dims
82+
return variadic_mul(
83+
*[size for dim, size in original_var.sizes if dim not in reduced_dims]
84+
)
85+
86+
87+
def mean(x, dim: REDUCE_DIM):
88+
x = as_xtensor(x)
89+
sum_x = sum(x, dim)
90+
n = _infer_reduced_size(x, sum_x)
91+
return sum_x / n
92+
93+
94+
def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
95+
x = as_xtensor(x)
96+
x_mean = mean(x, dim)
97+
n = _infer_reduced_size(x, x_mean)
98+
return square(x - x_mean) / (n - ddof)
99+
100+
101+
def std(x, dim: REDUCE_DIM, *, ddof: int = 0):
102+
return sqrt(var(x, dim, ddof=ddof))
103+
104+
105+
class XCumReduce(XOp):
106+
__props__ = ("binary_op", "dims")
107+
108+
def __init__(self, binary_op, dims: Sequence[str]):
109+
self.binary_op = binary_op
110+
self.dims = tuple(sorted(dims)) # Order doesn't matter
111+
112+
def make_node(self, x: Variable) -> Apply:
113+
x = as_xtensor(x)
114+
out = x.type()
115+
return Apply(self, [x], [out])
116+
117+
118+
def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
119+
dims = _process_user_dims(x, dim)
120+
return XCumReduce(dims=dims, binary_op=binary_op)(x)
121+
122+
123+
cumsum = partial(cumreduce, binary_op=ps.add)
124+
cumprod = partial(cumreduce, binary_op=ps.mul)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
import pytensor.xtensor.rewriting.basic
2+
import pytensor.xtensor.rewriting.reduction
23
import pytensor.xtensor.rewriting.shape
34
import pytensor.xtensor.rewriting.vectorization
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from functools import partial
2+
3+
import pytensor.scalar as ps
4+
from pytensor.graph.rewriting.basic import node_rewriter
5+
from pytensor.tensor.extra_ops import CumOp
6+
from pytensor.tensor.math import All, Any, CAReduce, Max, Min, Prod, Sum
7+
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
8+
from pytensor.xtensor.reduction import XCumReduce, XReduce
9+
from pytensor.xtensor.rewriting.utils import register_xcanonicalize
10+
11+
12+
@register_xcanonicalize
13+
@node_rewriter(tracks=[XReduce])
14+
def lower_reduce(fgraph, node):
15+
[x] = node.inputs
16+
[out] = node.outputs
17+
x_dims = x.type.dims
18+
reduce_dims = node.op.dims
19+
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
20+
21+
if not reduce_axis:
22+
return [x]
23+
24+
match node.op.binary_op:
25+
case ps.add:
26+
tensor_op_class = Sum
27+
case ps.mul:
28+
tensor_op_class = Prod
29+
case ps.and_:
30+
tensor_op_class = All
31+
case ps.or_:
32+
tensor_op_class = Any
33+
case ps.scalar_maximum:
34+
tensor_op_class = Max
35+
case ps.scalar_minimum:
36+
tensor_op_class = Min
37+
case _:
38+
# Case without known/predefined Ops
39+
tensor_op_class = partial(CAReduce, scalar_op=node.op.binary_op)
40+
41+
x_tensor = tensor_from_xtensor(x)
42+
out_tensor = tensor_op_class(axis=reduce_axis)(x_tensor)
43+
new_out = xtensor_from_tensor(out_tensor, out.type.dims)
44+
return [new_out]
45+
46+
47+
@register_xcanonicalize
48+
@node_rewriter(tracks=[XCumReduce])
49+
def lower_cumreduce(fgraph, node):
50+
[x] = node.inputs
51+
x_dims = x.type.dims
52+
reduce_dims = node.op.dims
53+
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
54+
55+
if not reduce_axis:
56+
return [x]
57+
58+
match node.op.binary_op:
59+
case ps.add:
60+
tensor_op_class = partial(CumOp, mode="add")
61+
case ps.mul:
62+
tensor_op_class = partial(CumOp, mode="mul")
63+
case _:
64+
# We don't know how to convert an arbitrary binary cum/reduce Op
65+
return None
66+
67+
# Each dim corresponds to an application of Cumsum/Cumprod
68+
out_tensor = tensor_from_xtensor(x)
69+
for axis in reduce_axis:
70+
out_tensor = tensor_op_class(axis=axis)(out_tensor)
71+
out = xtensor_from_tensor(out_tensor, x.type.dims)
72+
return [out]

pytensor/xtensor/special.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pytensor.xtensor.math import exp
2+
from pytensor.xtensor.reduction import REDUCE_DIM
3+
4+
5+
def softmax(x, dim: REDUCE_DIM = None):
6+
exp_x = exp(x)
7+
return exp_x / exp_x.sum(dim=dim)

pytensor/xtensor/type.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,40 @@ def imag(self):
357357
def real(self):
358358
return px.math.real(self)
359359

360-
# @property
361-
# def T(self):
362-
# ...
360+
# Aggregation
361+
# https://docs.xarray.dev/en/latest/api.html#id6
362+
def all(self, dim):
363+
return px.reduction.all(self, dim)
364+
365+
def any(self, dim):
366+
return px.reduction.any(self, dim)
367+
368+
def max(self, dim):
369+
return px.reduction.max(self, dim)
370+
371+
def min(self, dim):
372+
return px.reduction.min(self, dim)
373+
374+
def mean(self, dim):
375+
return px.reduction.mean(self, dim)
376+
377+
def prod(self, dim):
378+
return px.reduction.prod(self, dim)
379+
380+
def sum(self, dim):
381+
return px.reduction.sum(self, dim)
382+
383+
def std(self, dim):
384+
return px.reduction.std(self, dim)
385+
386+
def var(self, dim):
387+
return px.reduction.var(self, dim)
388+
389+
def cumsum(self, dim):
390+
return px.reduction.cumsum(self, dim)
391+
392+
def cumprod(self, dim):
393+
return px.reduction.cumprod(self, dim)
363394

364395

365396
class XTensorConstantSignature(tuple):

tests/xtensor/test_reduction.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
3+
from pytensor.xtensor.type import xtensor
4+
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
5+
6+
7+
@pytest.mark.parametrize(
8+
"dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
9+
)
10+
@pytest.mark.parametrize(
11+
"method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:]
12+
)
13+
def test_reduction(method, dim):
14+
x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
15+
out = getattr(x, method)(dim=dim)
16+
17+
fn = xr_function([x], out)
18+
x_test = xr_arange_like(x)
19+
20+
xr_assert_allclose(
21+
fn(x_test),
22+
getattr(x_test, method)(dim=dim),
23+
)

0 commit comments

Comments
 (0)