Skip to content

Add rewrite for Sum(MakeVector) #346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to
from pytensor.tensor.math import Sum, add
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import eq
from pytensor.tensor.shape import Shape_i
Expand Down Expand Up @@ -956,6 +957,41 @@ def local_join_make_vector(fgraph, node):
return [ret]


@register_specialize
@register_canonicalize
@register_useless
@node_rewriter([Sum])
def local_sum_make_vector(fgraph, node):
"""A sum of a MakeVector node is just the sum of the elements."""
(array,) = node.inputs

if array.owner is None:
return

if not isinstance(array.owner.op, MakeVector):
return

if node.op.axis == ():
return [array]

# If this is not the case the sum is invalid
assert node.op.axis is None or node.op.axis == (0,) or node.op.axis == (-1,)

elements = array.owner.inputs
acc_dtype = node.op.acc_dtype
out_dtype = node.op.dtype
if len(elements) == 0:
element_sum = zeros(dtype=out_dtype, shape=())
elif len(elements) == 1:
element_sum = cast(elements[0], out_dtype)
else:
element_sum = cast(
add(*[cast(value, acc_dtype) for value in elements]), out_dtype
)

return [element_sum]


@register_useless("local_remove_switch_const_cond")
@register_canonicalize("fast_compile", "local_remove_switch_const_cond")
@register_specialize
Expand Down
41 changes: 40 additions & 1 deletion tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations
from pytensor.graph.basic import equal_computations, vars_between
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
Expand All @@ -31,6 +31,7 @@
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import (
Sum,
add,
bitwise_and,
bitwise_or,
Expand Down Expand Up @@ -1300,6 +1301,44 @@ def test_local_join_make_vector():
assert check_stack_trace(f, ops_to_check="all")


def test_local_sum_make_vector():
a, b, c = scalars("abc")
mv = MakeVector(config.floatX)
output = mv(a, b, c).sum()

output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector))

# Check for empty sum
a, b, c = scalars("abc")
mv = MakeVector(config.floatX)
output = mv(a, b, c).sum(axis=[])

output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))

# Check empty MakeVector
mv = MakeVector(config.floatX)
output = mv().sum()

output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))

mv = MakeVector(config.floatX)
output = mv(a).sum()

output = rewrite_graph(output)
between = vars_between([a, b, c], [output])
for var in between:
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))


@pytest.mark.parametrize(
"dtype",
[
Expand Down