diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 58a5918c12..ff103e9fc1 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -43,6 +43,7 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
+from pytensor.tensor.math import Sum, add
 from pytensor.tensor.math import all as at_all
 from pytensor.tensor.math import eq
 from pytensor.tensor.shape import Shape_i
@@ -956,6 +957,41 @@ def local_join_make_vector(fgraph, node):
     return [ret]
 
 
+@register_specialize
+@register_canonicalize
+@register_useless
+@node_rewriter([Sum])
+def local_sum_make_vector(fgraph, node):
+    """A sum of a MakeVector node is just the sum of the elements."""
+    (array,) = node.inputs
+
+    if array.owner is None:
+        return
+
+    if not isinstance(array.owner.op, MakeVector):
+        return
+
+    if node.op.axis == ():
+        return [array]
+
+    # If this is not the case the sum is invalid
+    assert node.op.axis is None or node.op.axis == (0,) or node.op.axis == (-1,)
+
+    elements = array.owner.inputs
+    acc_dtype = node.op.acc_dtype
+    out_dtype = node.op.dtype
+    if len(elements) == 0:
+        element_sum = zeros(dtype=out_dtype, shape=())
+    elif len(elements) == 1:
+        element_sum = cast(elements[0], out_dtype)
+    else:
+        element_sum = cast(
+            add(*[cast(value, acc_dtype) for value in elements]), out_dtype
+        )
+
+    return [element_sum]
+
+
 @register_useless("local_remove_switch_const_cond")
 @register_canonicalize("fast_compile", "local_remove_switch_const_cond")
 @register_specialize
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index fe2b795907..3c3f917bc9 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -12,7 +12,7 @@
 from pytensor.compile.mode import get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import equal_computations
+from pytensor.graph.basic import equal_computations, vars_between
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -31,6 +31,7 @@
 )
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import (
+    Sum,
     add,
     bitwise_and,
     bitwise_or,
@@ -1300,6 +1301,44 @@ def test_local_join_make_vector():
     assert check_stack_trace(f, ops_to_check="all")
 
 
+def test_local_sum_make_vector():
+    a, b, c = scalars("abc")
+    mv = MakeVector(config.floatX)
+    output = mv(a, b, c).sum()
+
+    output = rewrite_graph(output)
+    between = vars_between([a, b, c], [output])
+    for var in between:
+        assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector))
+
+    # Check for empty sum
+    a, b, c = scalars("abc")
+    mv = MakeVector(config.floatX)
+    output = mv(a, b, c).sum(axis=[])
+
+    output = rewrite_graph(output)
+    between = vars_between([a, b, c], [output])
+    for var in between:
+        assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
+
+    # Check empty MakeVector
+    mv = MakeVector(config.floatX)
+    output = mv().sum()
+
+    output = rewrite_graph(output)
+    between = vars_between([a, b, c], [output])
+    for var in between:
+        assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
+
+    mv = MakeVector(config.floatX)
+    output = mv(a).sum()
+
+    output = rewrite_graph(output)
+    between = vars_between([a, b, c], [output])
+    for var in between:
+        assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
+
+
 @pytest.mark.parametrize(
     "dtype",
     [