File tree Expand file tree Collapse file tree 2 files changed +25
-1
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +25
-1
lines changed Original file line number Diff line number Diff line change @@ -980,7 +980,14 @@ def local_sum_make_vector(fgraph, node):
980
980
elements = array .owner .inputs
981
981
acc_dtype = node .op .acc_dtype
982
982
out_dtype = node .op .dtype
983
- element_sum = cast (add (* [cast (value , acc_dtype ) for value in elements ]), out_dtype )
983
+ if len (elements ) == 0 :
984
+ element_sum = zeros (dtype = out_dtype , shape = ())
985
+ elif len (elements ) == 1 :
986
+ element_sum = cast (elements [0 ], out_dtype )
987
+ else :
988
+ element_sum = cast (
989
+ add (* [cast (value , acc_dtype ) for value in elements ]), out_dtype
990
+ )
984
991
985
992
return [element_sum ]
986
993
Original file line number Diff line number Diff line change @@ -1321,6 +1321,23 @@ def test_local_sum_make_vector():
1321
1321
for var in between :
1322
1322
assert (var .owner is None ) or (not isinstance (var .owner .op , Sum ))
1323
1323
1324
+ # Check empty MakeVector
1325
+ mv = MakeVector (config .floatX )
1326
+ output = mv ().sum ()
1327
+
1328
+ output = rewrite_graph (output )
1329
+ between = vars_between ([a , b , c ], [output ])
1330
+ for var in between :
1331
+ assert (var .owner is None ) or (not isinstance (var .owner .op , Sum ))
1332
+
1333
+ mv = MakeVector (config .floatX )
1334
+ output = mv (a ).sum ()
1335
+
1336
+ output = rewrite_graph (output )
1337
+ between = vars_between ([a , b , c ], [output ])
1338
+ for var in between :
1339
+ assert (var .owner is None ) or (not isinstance (var .owner .op , Sum ))
1340
+
1324
1341
1325
1342
@pytest .mark .parametrize (
1326
1343
"dtype" ,
You can’t perform that action at this time.
0 commit comments