
Commit 5db0d83

Be consistent about second vs alloc in rewrites
1 parent 67519be commit 5db0d83

File tree: 2 files changed, 34 additions & 25 deletions

pytensor/tensor/rewriting/basic.py

Lines changed: 23 additions & 1 deletion
@@ -1,4 +1,26 @@
-""" Tensor optimizations addressing the ops in basic.py."""
+""" Tensor optimizations addressing the ops in basic.py.
+
+Notes
+-----
+There are two ways of broadcasting arrays:
+second(x, y) == alloc(y, broadcast_shapes(x.shape, y.shape))
+
+The second can be more efficient because x doesn't usually need to be computed when we only want its shape.
+It may also allow other rewrites that don't try to modify x when it has multiple clients (for fear of duplicating computation).
+
+However, the first one is easier to reason about.
+Knowing we have such a graph allows us to do certain rewrites, such as "sinking" broadcasting operations below Elemwise.
+The same rewrites with alloc would be more complicated, as we would need to symbolically combine the shapes of each alloc.
+
+As an example, contrast rewriting the following two equivalent graphs:
+
+alloc(x, broadcast_shapes(x.shape, y.shape)) + alloc(y, broadcast_shapes(x.shape, y.shape)) -> x + y
+second(y, x) + second(x, y) -> x + y
+
+Theano developers (mostly) preferred to use the first form during canonicalization and to introduce the second form later,
+via rewrites like `local_fill_to_alloc` and the `alloc_like` helper used inside rewrites.
+Many stabilize and specialize rewrites refuse to apply when a variable has multiple clients, so this is important.
+"""
 
 import logging
 from typing import TYPE_CHECKING, Optional, Union
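The equivalence in the new docstring can be checked numerically. A minimal sketch (illustrative, not from this commit), using `alloc` from `pytensor.tensor` and `second` from `pytensor.tensor.basic`; here the broadcast shape is just `x.shape`, standing in for `broadcast_shapes(x.shape, y.shape)`:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import second  # same Op that `fill` aliases

x = pt.dmatrix("x")
y = pt.dvector("y")

# Form 1: second(x, y) takes x's broadcast shape and y's values.
z_second = second(x, y)

# Form 2: an explicit alloc; only x's shape (never x's values) is needed.
z_alloc = pt.alloc(y, x.shape[0], x.shape[1])

f = pytensor.function([x, y], [z_second, z_alloc])
out_second, out_alloc = f(np.zeros((2, 3)), np.arange(3.0))
assert np.array_equal(out_second, out_alloc)  # both are y tiled to (2, 3)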

pytensor/tensor/rewriting/math.py

Lines changed: 11 additions & 24 deletions
@@ -30,7 +30,6 @@
     cast,
     constant,
     extract_constant,
-    fill,
     get_underlying_scalar_constant_value,
     ones_like,
     switch,
@@ -2041,8 +2040,6 @@ def local_zero_div(fgraph, node):
 @register_specialize
 @node_rewriter([at_pow])
 def local_pow_specialize(fgraph, node):
-    # here, we are past the point of canonicalization, so we don't want
-    # to put in un-necessary fills.
     if node.op == at_pow:
         # the idea here is that we have pow(x, y)
         odtype = node.outputs[0].dtype
@@ -2057,7 +2054,7 @@ def local_pow_specialize(fgraph, node):
             if np.all(y == 1):
                 rval = [xsym]
             if np.all(y == 0):
-                rval = [fill(xsym, np.asarray(1, dtype=odtype))]
+                rval = [alloc_like(1, xsym, fgraph)]
             if np.all(y == 0.5):
                 rval = [sqrt(xsym)]
             if np.all(y == -0.5):
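The `y == 0` branch above is what turns `x ** 0` into ones with `x`'s broadcast shape rather than an actual power. A rough way to see the specialization (illustrative, not from this commit; the exact compiled ops depend on the active rewrite set):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dmatrix("x")
f = pytensor.function([x], x ** 0)

# Inspect the compiled graph: ideally an allocation of ones shaped like x,
# with no Pow node left (depends on which rewrites fired).
pytensor.dprint(f)

out = f(np.full((2, 3), 7.0))
assert np.array_equal(out, np.ones((2, 3)))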
@@ -2158,9 +2155,7 @@ def local_mul_specialize(fgraph, node):
     mul(-1, x, y) -/-> neg(mul(x, y))
 
     """
-    # here, we are past the point of canonicalization, so we don't
-    # want to put in un-necessary fills.
-    #
+
     # at this point [post canonicalize], mul() may have many inputs.
     if node.op == mul:
         # the idea here is that we have pow(x, y)
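The docstring line above records that the constant -1 is only folded into a `neg` when `mul` has one other input. A hedged illustration (not from this commit); the asserts check values only, and the comments describe expected, not guaranteed, compiled graphs:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
y = pt.dvector("y")

f_single = pytensor.function([x], -1 * x)        # expected to become neg(x)
f_multi = pytensor.function([x, y], -1 * x * y)  # -1 expected to stay inside the mul

v = np.array([1.0, -2.0])
w = np.array([3.0, 4.0])
assert np.array_equal(f_single(v), -v)
assert np.array_equal(f_multi(v, w), -v * w)
pytensor.dprint(f_multi)  # check that no separate neg(mul(x, y)) was introduced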
@@ -2221,16 +2216,7 @@ def local_mul_specialize(fgraph, node):
 
 @register_specialize
 @node_rewriter([add])
-def local_add_specialize(fgraph, node):
-    """Remove zeros from ``add``s.
-
-    TODO: This should be a canonicalization, no?
-    """
-    # here, we are past the point of canonicalization, so we don't want
-    # to put in un-necessary fills.
-    if node.op != add:
-        return False
-
+def local_add_remove_zeros(fgraph, node):
     new_inputs = []
     for inp in node.inputs:
         try:
@@ -2253,12 +2239,12 @@ def local_add_specialize(fgraph, node):
             # Reuse call to constant for cache()
             cst = constant(np.zeros((1,) * ndim, dtype=dtype))
             assert cst.type.broadcastable == (True,) * ndim
-            return [broadcast_arrays(cst, *node.inputs)[0]]
+            return [alloc_like(cst, node_output, fgraph)]
 
         if len(new_inputs) == 1:
-            ret = [broadcast_arrays(new_inputs[0], *node.inputs)[0]]
+            ret = [alloc_like(new_inputs[0], node_output, fgraph)]
         else:
-            ret = [broadcast_arrays(add(*new_inputs), *node.inputs)[0]]
+            ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
 
         # The dtype should not be changed. It can happen if the input
         # that was forcing upcasting was equal to 0.
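The renamed `local_add_remove_zeros` drops zero terms from an `add` and uses `alloc_like` so the surviving term keeps the broadcast shape of the original sum. A small sketch of the observable behaviour (illustrative, not from this commit), assuming the default rewrite set:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
zeros_col = pt.constant(np.zeros((4, 1)))  # zero term that only contributes shape

f = pytensor.function([x], x + zeros_col)

# The zero term should be rewritten away, leaving x broadcast to shape (4, len(x)).
out = f(np.array([1.0, 2.0, 3.0]))
assert np.array_equal(out, np.tile([1.0, 2.0, 3.0], (4, 1)))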
@@ -2376,7 +2362,7 @@ def local_log1p(fgraph, node):
                 ninp = nonconsts[0]
                 if ninp.dtype != log_arg.type.dtype:
                     ninp = ninp.astype(node.outputs[0].dtype)
-                return [broadcast_arrays(log1p(ninp), *scalar_inputs)[0]]
+                return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
 
     elif log_arg.owner and log_arg.owner.op == sub:
         one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
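The return above broadcasts `log1p(ninp)` to the full output of the `log(1 + x)` node; the rewrite itself exists for numerical accuracy with tiny `x`. A hedged demonstration (not from this commit), assuming the stabilization rewrite fires in the default mode, hence prints rather than asserts:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
f = pytensor.function([x], pt.log(1 + x))

tiny = np.array([1e-20])
# Naive log(1 + 1e-20) underflows to 0.0; with local_log1p applied the
# compiled graph computes log1p(x) and stays accurate.
print(f(tiny))         # expected ~1e-20 if the rewrite fired
print(np.log1p(tiny))  # reference value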
@@ -3572,10 +3558,11 @@ def local_reciprocal_1_plus_exp(fgraph, node):
             if nonconsts[0].owner and nonconsts[0].owner.op == exp:
                 if scalars_ and np.allclose(np.sum(scalars_), 1):
                     out = [
-                        broadcast_arrays(
+                        alloc_like(
                             sigmoid(neg(nonconsts[0].owner.inputs[0])),
-                            *scalar_inputs,
-                        )[0]
+                            node.outputs[0],
+                            fgraph,
+                        )
                     ]
                     # keep combined stack traces of
                     # exp(x): nonconsts[0],
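Only the broadcasting helper changes here; the rewrite still collapses `reciprocal(1 + exp(x))` into `sigmoid(-x)`. A hedged numeric check (not from this commit); the assert holds whether or not the rewrite fires, and `dprint` shows whether a sigmoid Op replaced the exp/reciprocal:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
f = pytensor.function([x], 1 / (1 + pt.exp(x)))

vals = np.array([-3.0, 0.0, 3.0])
expected = 1 / (1 + np.exp(vals))   # == sigmoid(-vals)
assert np.allclose(f(vals), expected)
pytensor.dprint(f)  # look for a single sigmoid Op if the rewrite fired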
