Skip to content

Commit 41114da

Browse files
committed
Specialize Zero Alloc
1 parent 5fadf46 commit 41114da

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

pytensor/tensor/basic.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,14 @@ def _check_runtime_broadcast(node, value, shape):
16341634
if v_static_dim is None and value_dim == 1 and out_dim != 1:
16351635
raise ValueError(Alloc._runtime_broadcast_error_msg)
16361636

1637+
@staticmethod
1638+
def value_is_scalar_zero(x: TensorVariable) -> bool:
1639+
return (
1640+
all(x.type.broadcastable)
1641+
and isinstance(x, Constant)
1642+
and (x.unique_value == 0)
1643+
)
1644+
16371645
def perform(self, node, inputs, out_):
16381646
(out,) = out_
16391647
v = inputs[0]
@@ -1659,6 +1667,7 @@ def c_code(self, node, name, inp, out, sub):
16591667
o_static_shape = node.outputs[0].type.shape
16601668
v_ndim = len(v_static_shape)
16611669
o_ndim = len(o_static_shape)
1670+
is_zero = self.value_is_scalar_zero(node.inputs[0])
16621671
assert o_ndim == len(inp[1:])
16631672

16641673
# Declare variables
@@ -1699,16 +1708,18 @@ def c_code(self, node, name, inp, out, sub):
16991708
{fail}
17001709
}}
17011710
}}
1702-
1711+
if ({int(is_zero)} && (PyArray_IS_C_CONTIGUOUS({zz}) || PyArray_IS_F_CONTIGUOUS({zz}))){{
1712+
PyArray_FILLWBYTE({zz}, 0);
1713+
}}
17031714
// This function takes care of broadcasting
1704-
if (PyArray_CopyInto({zz}, {vv}) == -1)
1715+
else if (PyArray_CopyInto({zz}, {vv}) == -1)
17051716
{fail}
17061717
"""
17071718

17081719
return code
17091720

17101721
def c_code_cache_version(self):
1711-
return (4,)
1722+
return (5,)
17121723

17131724
def infer_shape(self, fgraph, node, input_shapes):
17141725
return [node.inputs[1:]]

0 commit comments

Comments
 (0)