|
12 | 12 | from pytensor.compile.mode import get_default_mode, get_mode
|
13 | 13 | from pytensor.compile.ops import DeepCopyOp, deep_copy_op
|
14 | 14 | from pytensor.configdefaults import config
|
15 |
| -from pytensor.graph.basic import equal_computations, vars_between |
| 15 | +from pytensor.graph.basic import equal_computations |
16 | 16 | from pytensor.graph.fg import FunctionGraph
|
17 | 17 | from pytensor.graph.rewriting.basic import check_stack_trace, out2in
|
18 | 18 | from pytensor.graph.rewriting.db import RewriteDatabaseQuery
|
|
26 | 26 | ScalarFromTensor,
|
27 | 27 | Split,
|
28 | 28 | TensorFromScalar,
|
| 29 | + cast, |
29 | 30 | join,
|
30 | 31 | tile,
|
31 | 32 | )
|
32 | 33 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
33 | 34 | from pytensor.tensor.math import (
|
34 |
| - Sum, |
35 | 35 | add,
|
36 | 36 | bitwise_and,
|
37 | 37 | bitwise_or,
|
@@ -1298,41 +1298,48 @@ def test_local_join_make_vector():
|
1298 | 1298 |
|
1299 | 1299 |
|
1300 | 1300 | def test_local_sum_make_vector():
|
| 1301 | + # To check that rewrite is applied, we must enforce dtype to |
| 1302 | + # allow rewrite to occur even if floatX != "float64" |
1301 | 1303 | a, b, c = scalars("abc")
|
1302 | 1304 | mv = MakeVector(config.floatX)
|
1303 |
| - output = mv(a, b, c).sum() |
1304 |
| - |
1305 |
| - output = rewrite_graph(output) |
1306 |
| - between = vars_between([a, b, c], [output]) |
1307 |
| - for var in between: |
1308 |
| - assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector)) |
| 1305 | + output = mv(a, b, c).sum(dtype="float64") |
| 1306 | + rewrite_output = rewrite_graph(output) |
| 1307 | + expected_output = cast( |
| 1308 | + add(*[cast(value, "float64") for value in [a, b, c]]), dtype="float64" |
| 1309 | + ) |
| 1310 | + assert equal_computations([expected_output], [rewrite_output]) |
1309 | 1311 |
|
1310 |
| - # Check for empty sum |
| 1312 | + # Empty axes should return input vector since no sum is applied |
1311 | 1313 | a, b, c = scalars("abc")
|
1312 | 1314 | mv = MakeVector(config.floatX)
|
1313 | 1315 | output = mv(a, b, c).sum(axis=[])
|
| 1316 | + rewrite_output = rewrite_graph(output) |
| 1317 | + expected_output = mv(a, b, c) |
| 1318 | + assert equal_computations([expected_output], [rewrite_output]) |
1314 | 1319 |
|
1315 |
| - output = rewrite_graph(output) |
1316 |
| - between = vars_between([a, b, c], [output]) |
1317 |
| - for var in between: |
1318 |
| - assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
1319 |
| - |
1320 |
| - # Check empty MakeVector |
| 1320 | + # Empty input should return 0 |
1321 | 1321 | mv = MakeVector(config.floatX)
|
1322 | 1322 | output = mv().sum()
|
| 1323 | + rewrite_output = rewrite_graph(output) |
| 1324 | + expected_output = pt.as_tensor(0, dtype=config.floatX) |
| 1325 | + assert equal_computations([expected_output], [rewrite_output]) |
1323 | 1326 |
|
1324 |
| - output = rewrite_graph(output) |
1325 |
| - between = vars_between([a, b, c], [output]) |
1326 |
| - for var in between: |
1327 |
| - assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
1328 |
| - |
| 1327 | + # Single element input should return element value |
| 1328 | + a = scalars("a") |
1329 | 1329 | mv = MakeVector(config.floatX)
|
1330 | 1330 | output = mv(a).sum()
|
1331 |
| - |
1332 |
| - output = rewrite_graph(output) |
1333 |
| - between = vars_between([a, b, c], [output]) |
1334 |
| - for var in between: |
1335 |
| - assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) |
| 1331 | + rewrite_output = rewrite_graph(output) |
| 1332 | + expected_output = cast(a, config.floatX) |
| 1333 | + assert equal_computations([expected_output], [rewrite_output]) |
| 1334 | + |
| 1335 | + # This is a regression test for #653. Ensure that rewrite is NOT |
| 1336 | + # applied when user requests float32 |
| 1337 | + with config.change_flags(floatX="float32", warn_float64="raise"): |
| 1338 | + a, b, c = scalars("abc") |
| 1339 | + mv = MakeVector(config.floatX) |
| 1340 | + output = mv(a, b, c).sum() |
| 1341 | + rewrite_output = rewrite_graph(output) |
| 1342 | + assert equal_computations([output], [rewrite_output]) |
1336 | 1343 |
|
1337 | 1344 |
|
1338 | 1345 | @pytest.mark.parametrize(
|
|
0 commit comments