Skip to content

Commit f449291

Browse files
ricardoV94 and Ch0ronomato
authored and committed
Generalize and rename local_reduce_chain
1 parent 31d34d0 commit f449291

File tree

2 files changed

+205
-182
lines changed

2 files changed

+205
-182
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,11 @@
100100
values_eq_approx_remove_inf_nan,
101101
values_eq_approx_remove_nan,
102102
)
103-
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
103+
from pytensor.tensor.variable import (
104+
TensorConstant,
105+
TensorVariable,
106+
get_unique_constant_value,
107+
)
104108

105109

106110
def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
@@ -1575,42 +1579,48 @@ def local_sum_prod_all_to_none(fgraph, node):
15751579

15761580

15771581
@register_canonicalize
1578-
@node_rewriter([Sum, Prod])
1579-
def local_op_of_op(fgraph, node):
1582+
@node_rewriter([CAReduce])
1583+
def local_reduce_chain(fgraph, node) -> list[TensorVariable] | None:
15801584
"""
1581-
Prod(Prod()) -> single Prod()
1582-
or
15831585
Sum(Sum()) -> single Sum()
1586+
or any CAReduce(CAReduce(x)) of the same type
15841587
15851588
"""
1586-
op_type = Sum if isinstance(node.op, Sum) else Prod
1587-
(node_inps,) = node.inputs
1588-
out_dtype = node.op.dtype
1589-
# This is done to make sure the rewrite doesn't affect other
1590-
# computations.
1591-
if len(fgraph.clients[node_inps]) == 1:
1592-
if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)):
1593-
# check to see either the inner or outer prod is doing a
1594-
# product over all axis, in which case we can remove it
1595-
if node_inps.owner.op.axis is None or node.op.axis is None:
1596-
return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]
1597-
1598-
# figure out which axes were in the original sum
1599-
newaxis = list(node_inps.owner.op.axis)
1600-
for i in node.op.axis:
1601-
new_i = i
1602-
for ii in node_inps.owner.op.axis:
1603-
if new_i >= ii:
1604-
new_i += 1
1605-
assert new_i not in newaxis
1606-
newaxis.append(new_i)
1607-
1608-
assert len(newaxis) == len(
1609-
list(node_inps.owner.op.axis) + list(node.op.axis)
1610-
)
1589+
[inner_reduce] = node.inputs
1590+
if not (inner_reduce.owner and isinstance(inner_reduce.owner.op, CAReduce)):
1591+
return None
1592+
1593+
# Don't apply rewrite if inner_reduce is used elsewhere
1594+
if len(fgraph.clients[inner_reduce]) > 1:
1595+
return None
1596+
1597+
# Check if CAReduces have the same scalar op
1598+
outer_op: CAReduce = node.op
1599+
inner_op = inner_reduce.owner.op
1600+
1601+
if outer_op.scalar_op != inner_op.scalar_op:
1602+
return None
16111603

1612-
combined = op_type(newaxis, dtype=out_dtype)
1613-
return [combined(node_inps.owner.inputs[0])]
1604+
outer_axis = outer_op.axis
1605+
inner_axis = inner_op.axis
1606+
[x] = inner_reduce.owner.inputs
1607+
# check to see either the inner or outer prod is doing a
1608+
# product over all axis, in which case we can remove it
1609+
if outer_axis is None or inner_axis is None:
1610+
return [outer_op.clone(axis=None)(x)]
1611+
1612+
# Merge axis
1613+
newaxis = list(inner_axis)
1614+
for i in outer_axis:
1615+
new_i = i
1616+
for ii in inner_axis:
1617+
if new_i >= ii:
1618+
new_i += 1
1619+
assert new_i not in newaxis
1620+
newaxis.append(new_i)
1621+
1622+
assert len(newaxis) == len(inner_axis) + len(outer_axis)
1623+
return [outer_op.clone(axis=sorted(newaxis))(x)]
16141624

16151625

16161626
@register_canonicalize

0 commit comments

Comments
 (0)