Skip to content

Commit e904f5e

Browse files
committed
Specialize Zero Alloc
1 parent 17148fd commit e904f5e

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

pytensor/tensor/basic.py

+10 -3
Original file line number | Diff line number | Diff line change
@@ -1659,6 +1659,11 @@ def c_code(self, node, name, inp, out, sub):
16591659
o_static_shape = node.outputs[0].type.shape
16601660
v_ndim = len(v_static_shape)
16611661
o_ndim = len(o_static_shape)
1662+
is_zero = (
1663+
all(node.inputs[0].type.broadcastable)
1664+
and isinstance(node.inputs[0], Constant)
1665+
and (node.inputs[0].unique_value == 0)
1666+
)
16621667
assert o_ndim == len(inp[1:])
16631668

16641669
# Declare variables
@@ -1699,16 +1704,18 @@ def c_code(self, node, name, inp, out, sub):
16991704
{fail}
17001705
}}
17011706
}}
1702-
1707+
if ({int(is_zero)} && (PyArray_IS_C_CONTIGUOUS({zz}) || PyArray_IS_F_CONTIGUOUS({zz}))){{
1708+
PyArray_FILLWBYTE({zz}, 0);
1709+
}}
17031710
// This function takes care of broadcasting
1704-
if (PyArray_CopyInto({zz}, {vv}) == -1)
1711+
else if (PyArray_CopyInto({zz}, {vv}) == -1)
17051712
{fail}
17061713
"""
17071714

17081715
return code
17091716

17101717
def c_code_cache_version(self):
1711-
return (4,)
1718+
return (5,)
17121719

17131720
def infer_shape(self, fgraph, node, input_shapes):
17141721
return [node.inputs[1:]]

0 commit comments

Comments
 (0)