Skip to content

Commit 717ba1a

Browse files
committed
Add rewrite for Sum(MakeVector)
1 parent e8bd0d7 commit 717ba1a

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from pytensor.tensor.elemwise import DimShuffle, Elemwise
4444
from pytensor.tensor.exceptions import NotScalarConstantError
4545
from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
46+
from pytensor.tensor.math import Sum, add
4647
from pytensor.tensor.math import all as at_all
4748
from pytensor.tensor.math import eq
4849
from pytensor.tensor.shape import Shape_i
@@ -956,6 +957,30 @@ def local_join_make_vector(fgraph, node):
956957
return [ret]
957958

958959

960+
@register_specialize
961+
@register_canonicalize
962+
@register_useless
963+
@node_rewriter([Sum])
964+
def local_sum_make_vector(fgraph, node):
965+
"""A sum of a MakeVector node is just the sum of the elements."""
966+
(array,) = node.inputs
967+
968+
if array.owner is None:
969+
return
970+
971+
if not isinstance(array.owner.op, MakeVector):
972+
return
973+
974+
if node.op.axis not in [None, 0, -1]:
975+
return
976+
977+
elements = array.owner.inputs
978+
dtype = node.op.acc_dtype
979+
element_sum = add(*[cast(value, dtype) for value in elements])
980+
981+
return [as_tensor_variable(element_sum)]
982+
983+
959984
@register_useless("local_remove_switch_const_cond")
960985
@register_canonicalize("fast_compile", "local_remove_switch_const_cond")
961986
@register_specialize

tests/tensor/rewriting/test_basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pytensor.graph.rewriting.utils import rewrite_graph
2020
from pytensor.printing import debugprint, pprint
2121
from pytensor.raise_op import Assert, CheckAndRaise
22+
from pytensor.scalar.basic import Add
2223
from pytensor.tensor.basic import (
2324
Alloc,
2425
Join,
@@ -102,6 +103,7 @@
102103
values_eq_approx_remove_nan,
103104
vector,
104105
)
106+
from pytensor.tensor.var import TensorVariable
105107
from tests import unittest_tools as utt
106108

107109

@@ -1300,6 +1302,20 @@ def test_local_join_make_vector():
13001302
assert check_stack_trace(f, ops_to_check="all")
13011303

13021304

1305+
def test_local_sum_make_vector():
1306+
a, b, c = scalars("abc")
1307+
mv = MakeVector(config.floatX)
1308+
output = mv(a, b, c).sum()
1309+
1310+
func = function([a, b, c], output)
1311+
1312+
elemwise = func.maker.fgraph.outputs[0].owner
1313+
# The MakeVector op should be optimized away, so we just
1314+
# take the sum of the scalars.
1315+
assert elemwise.inputs[0].name == "a"
1316+
assert isinstance(elemwise.inputs[0], TensorVariable)
1317+
1318+
13031319
@pytest.mark.parametrize(
13041320
"dtype",
13051321
[

0 commit comments

Comments
 (0)