Skip to content

Commit 9f6f048

Browse files
committed
fix(rewrite): Handle sum of empty make vector
1 parent 2870874 commit 9f6f048

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

pytensor/tensor/rewriting/basic.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,14 @@ def local_sum_make_vector(fgraph, node):
980980
elements = array.owner.inputs
981981
acc_dtype = node.op.acc_dtype
982982
out_dtype = node.op.dtype
983-
element_sum = cast(add(*[cast(value, acc_dtype) for value in elements]), out_dtype)
983+
if len(elements) == 0:
984+
element_sum = zeros(dtype=out_dtype, shape=())
985+
elif len(elements) == 1:
986+
element_sum = cast(elements[0], out_dtype)
987+
else:
988+
element_sum = cast(
989+
add(*[cast(value, acc_dtype) for value in elements]), out_dtype
990+
)
984991

985992
return [element_sum]
986993

tests/tensor/rewriting/test_basic.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,6 +1321,23 @@ def test_local_sum_make_vector():
13211321
for var in between:
13221322
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
13231323

1324+
# Check empty MakeVector
1325+
mv = MakeVector(config.floatX)
1326+
output = mv().sum()
1327+
1328+
output = rewrite_graph(output)
1329+
between = vars_between([a, b, c], [output])
1330+
for var in between:
1331+
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
1332+
1333+
mv = MakeVector(config.floatX)
1334+
output = mv(a).sum()
1335+
1336+
output = rewrite_graph(output)
1337+
between = vars_between([a, b, c], [output])
1338+
for var in between:
1339+
assert (var.owner is None) or (not isinstance(var.owner.op, Sum))
1340+
13241341

13251342
@pytest.mark.parametrize(
13261343
"dtype",

0 commit comments

Comments
 (0)