Skip to content

Commit 4ee3588

Browse files
authored
Prevent local_sum_make_vector from introducing internal float64 (#659)
1 parent d175203 commit 4ee3588

File tree

2 files changed

+38
-26
lines changed

2 files changed

+38
-26
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import numpy as np
2929

3030
import pytensor.scalar.basic as ps
31-
from pytensor import compile
31+
from pytensor import compile, config
3232
from pytensor.compile.ops import ViewOp
3333
from pytensor.graph import FunctionGraph
3434
from pytensor.graph.basic import Constant, Variable
@@ -941,6 +941,11 @@ def local_sum_make_vector(fgraph, node):
941941
elements = array.owner.inputs
942942
acc_dtype = node.op.acc_dtype
943943
out_dtype = node.op.dtype
944+
945+
# Skip rewrite if it would add unnecessary float64 to the graph
946+
if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
947+
return
948+
944949
if len(elements) == 0:
945950
element_sum = zeros(dtype=out_dtype, shape=())
946951
elif len(elements) == 1:

tests/tensor/rewriting/test_basic.py

Lines changed: 32 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.compile.mode import get_default_mode, get_mode
1313
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
1414
from pytensor.configdefaults import config
15-
from pytensor.graph.basic import equal_computations, vars_between
15+
from pytensor.graph.basic import equal_computations
1616
from pytensor.graph.fg import FunctionGraph
1717
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
1818
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -26,12 +26,12 @@
2626
ScalarFromTensor,
2727
Split,
2828
TensorFromScalar,
29+
cast,
2930
join,
3031
tile,
3132
)
3233
from pytensor.tensor.elemwise import DimShuffle, Elemwise
3334
from pytensor.tensor.math import (
34-
Sum,
3535
add,
3636
bitwise_and,
3737
bitwise_or,
@@ -1298,41 +1298,48 @@ def test_local_join_make_vector():
12981298

12991299

13001300
def test_local_sum_make_vector():
1301+
# To check that rewrite is applied, we must enforce dtype to
1302+
# allow rewrite to occur even if floatX != "float64"
13011303
a, b, c = scalars("abc")
13021304
mv = MakeVector(config.floatX)
1303-
output = mv(a, b, c).sum()
1304-
1305-
output = rewrite_graph(output)
1306-
between = vars_between([a, b, c], [output])
1307-
for var in between:
1308-
assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector))
1305+
output = mv(a, b, c).sum(dtype="float64")
1306+
rewrite_output = rewrite_graph(output)
1307+
expected_output = cast(
1308+
add(*[cast(value, "float64") for value in [a, b, c]]), dtype="float64"
1309+
)
1310+
assert equal_computations([expected_output], [rewrite_output])
13091311

1310-
# Check for empty sum
1312+
# Empty axes should return input vector since no sum is applied
13111313
a, b, c = scalars("abc")
13121314
mv = MakeVector(config.floatX)
13131315
output = mv(a, b, c).sum(axis=[])
1316+
rewrite_output = rewrite_graph(output)
1317+
expected_output = mv(a, b, c)
1318+
assert equal_computations([expected_output], [rewrite_output])
13141319

1315-
output = rewrite_graph(output)
1316-
between = vars_between([a, b, c], [output])
1317-
for var in between:
1318-
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
1319-
1320-
# Check empty MakeVector
1320+
# Empty input should return 0
13211321
mv = MakeVector(config.floatX)
13221322
output = mv().sum()
1323+
rewrite_output = rewrite_graph(output)
1324+
expected_output = pt.as_tensor(0, dtype=config.floatX)
1325+
assert equal_computations([expected_output], [rewrite_output])
13231326

1324-
output = rewrite_graph(output)
1325-
between = vars_between([a, b, c], [output])
1326-
for var in between:
1327-
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
1328-
1327+
# Single element input should return element value
1328+
a = scalars("a")
13291329
mv = MakeVector(config.floatX)
13301330
output = mv(a).sum()
1331-
1332-
output = rewrite_graph(output)
1333-
between = vars_between([a, b, c], [output])
1334-
for var in between:
1335-
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
1331+
rewrite_output = rewrite_graph(output)
1332+
expected_output = cast(a, config.floatX)
1333+
assert equal_computations([expected_output], [rewrite_output])
1334+
1335+
# This is a regression test for #653. Ensure that rewrite is NOT
1336+
# applied when user requests float32
1337+
with config.change_flags(floatX="float32", warn_float64="raise"):
1338+
a, b, c = scalars("abc")
1339+
mv = MakeVector(config.floatX)
1340+
output = mv(a, b, c).sum()
1341+
rewrite_output = rewrite_graph(output)
1342+
assert equal_computations([output], [rewrite_output])
13361343

13371344

13381345
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)