@@ -1634,6 +1634,14 @@ def _check_runtime_broadcast(node, value, shape):
             if v_static_dim is None and value_dim == 1 and out_dim != 1:
                 raise ValueError(Alloc._runtime_broadcast_error_msg)
 
+    @staticmethod
+    def value_is_scalar_zero(x: TensorVariable) -> bool:
+        return (
+            all(x.type.broadcastable)
+            and isinstance(x, Constant)
+            and (x.unique_value == 0)
+        )
+
     def perform(self, node, inputs, out_):
         (out,) = out_
         v = inputs[0]
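The new helper flags fill values that are compile-time scalar zeros: the value must be all-broadcastable (effectively 0-d), a `Constant`, and have a `unique_value` of 0. A minimal sketch of what it classifies, assuming this branch's `Alloc.value_is_scalar_zero` and the public pytensor API:

import pytensor.tensor as pt
from pytensor.tensor.basic import Alloc

# 0-d Constant whose single value is 0 -> eligible for the zero-fill fast path
assert Alloc.value_is_scalar_zero(pt.constant(0.0))
# Nonzero constants and non-Constant variables are excluded
assert not Alloc.value_is_scalar_zero(pt.constant(1.0))
assert not Alloc.value_is_scalar_zero(pt.scalar("x"))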
@@ -1659,6 +1667,7 @@ def c_code(self, node, name, inp, out, sub):
         o_static_shape = node.outputs[0].type.shape
         v_ndim = len(v_static_shape)
         o_ndim = len(o_static_shape)
+        is_zero = self.value_is_scalar_zero(node.inputs[0])
         assert o_ndim == len(inp[1:])
 
         # Declare variables
@@ -1699,16 +1708,18 @@ def c_code(self, node, name, inp, out, sub):
                     {fail}
                 }}
             }}
-
+            if ({int(is_zero)} && (PyArray_IS_C_CONTIGUOUS({zz}) || PyArray_IS_F_CONTIGUOUS({zz}))){{
+                PyArray_FILLWBYTE({zz}, 0);
+            }}
             // This function takes care of broadcasting
-            if (PyArray_CopyInto({zz}, {vv}) == -1)
+            else if (PyArray_CopyInto({zz}, {vv}) == -1)
                 {fail}
             """
 
         return code
 
     def c_code_cache_version(self):
-        return (4,)
+        return (5,)
 
     def infer_shape(self, fgraph, node, input_shapes):
         return [node.inputs[1:]]
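Because `is_zero` is computed during code generation, `{int(is_zero)}` is baked into the C source as a literal 0 or 1, so the compiler can eliminate the dead branch. When the branch is live and the freshly allocated output is C- or F-contiguous, `PyArray_FILLWBYTE(zz, 0)` memsets the buffer instead of broadcasting the 0-d value with `PyArray_CopyInto`; the all-zero byte pattern encodes numeric zero for NumPy's numeric dtypes, so the result is unchanged. Bumping `c_code_cache_version` to `(5,)` invalidates previously compiled modules so they are rebuilt with the new code. A rough NumPy analogue of the two paths (illustrative only; the actual change is in the generated C above):

import numpy as np

out = np.empty((4096, 4096))

# Fast path: byte-wise zero fill of a contiguous buffer, as PyArray_FILLWBYTE(out, 0) does.
out.view(np.uint8).fill(0)

# Generic path: broadcasting copy of a 0-d value, as PyArray_CopyInto does.
np.copyto(out, np.asarray(0.0))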