Skip to content

Commit 8c157a2

Browse files
committed
Fix local_fill_sink rewrite for multiple output Elemwise Ops
The changes get rid of the eager sink at the local node rewriter level. This was actually not working, because the nested replacements referenced variables that were never part of the original fgraph, and those replacements were being ignored altogether. Instead, we wrap the rewrite in an in2out that safely achieves the intended behavior.
1 parent d80c0bf commit 8c157a2

File tree

2 files changed

+49
-34
lines changed

2 files changed

+49
-34
lines changed

pytensor/tensor/rewriting/basic.py

+31-34
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from pytensor.graph.rewriting.db import RewriteDatabase
4343
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
44+
from pytensor.scalar.basic import Second
4445
from pytensor.tensor.basic import (
4546
Alloc,
4647
AllocEmpty,
@@ -320,56 +321,52 @@ def dimshuffled_alloc(i):
320321
return new_outs
321322

322323

323-
@register_canonicalize("shape_unsafe")
324324
@node_rewriter([Elemwise])
325325
def local_fill_sink(fgraph, node):
326326
"""
327327
f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e)))
328328
f needs to be an elemwise that isn't a fill.
329329
"""
330-
if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill:
330+
if isinstance(node.op.scalar_op, Second):
331331
return False
332+
332333
models = []
333334
inputs = []
334335
for inp in node.inputs:
335336
if inp.owner and inp.owner.op == fill:
336-
models.append(inp.owner.inputs[0])
337-
inputs.append(inp.owner.inputs[1])
337+
a, b = inp.owner.inputs
338+
if b.type.dtype != inp.dtype:
339+
# The input was implicitly casted by the fill operation
340+
b = b.cast(inp.dtype)
341+
models.append(a)
342+
inputs.append(b)
338343
else:
339344
inputs.append(inp)
345+
340346
if not models:
341347
return False
342-
c = node.op(*inputs)
343-
for model in models:
344-
if (
345-
model.type.dtype != c.type.dtype
346-
or model.type.broadcastable != c.type.broadcastable
347-
):
348-
c = fill(model, c)
349348

350-
# The newly created node c doesn't has 'clients',
351-
# so this iteration is took place with node.outputs[0]
352-
# TODO: This should just be a WalkingGraphRewrite!
353-
replacements = {node.outputs[0]: c}
354-
for client, cl_idx in fgraph.clients[node.outputs[0]]:
355-
if (
356-
hasattr(client, "op")
357-
and isinstance(client.op, Elemwise)
358-
and client.op != fill
359-
):
360-
client_inputs = client.inputs[:]
361-
client_inputs[cl_idx] = c
362-
new_client = client.op(*client_inputs)
363-
364-
# Add clients to new_client
365-
fgraph.clients[new_client.owner.outputs[0]] = fgraph.clients[
366-
client.outputs[0]
367-
]
368-
r = local_fill_sink.transform(fgraph, new_client.owner)
369-
if not r:
370-
continue
371-
replacements.update(r)
372-
return replacements
349+
outputs = node.op.make_node(*inputs).outputs
350+
351+
# Check if we need to propagate the fill to the new outputs
352+
# It's enough to check the first output, as Elemwise outputs must all have the same shapes
353+
# Note: There are orderings that may require fewer fills.
354+
old_bcast_pattern = node.outputs[0].type.broadcastable
355+
models_iter = iter(models)
356+
while old_bcast_pattern != outputs[0].type.broadcastable:
357+
model = next(models_iter)
358+
# Only apply this model if it would actually do anything
359+
if broadcasted_by(outputs[0], model):
360+
outputs = [fill(model, output) for output in outputs]
361+
362+
return outputs
363+
364+
365+
# The rewrite is wrapped in an in2out GraphRewriter
366+
# so that fill can be sunk to the terminal nodes in a single pass through the graph
367+
# without triggering other rewrites after each local substitution
368+
topological_fill_sink = in2out(local_fill_sink)
369+
register_canonicalize(topological_fill_sink, "shape_unsafe")
373370

374371

375372
@register_specialize("shape_unsafe")

tests/tensor/rewriting/test_basic.py

+18
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pytensor.graph.rewriting.utils import rewrite_graph
2020
from pytensor.printing import debugprint, pprint
2121
from pytensor.raise_op import Assert, CheckAndRaise
22+
from pytensor.scalar import Composite, float64
2223
from pytensor.tensor.basic import (
2324
Alloc,
2425
Join,
@@ -64,6 +65,7 @@
6465
local_merge_alloc,
6566
local_useless_alloc,
6667
local_useless_elemwise,
68+
topological_fill_sink,
6769
)
6870
from pytensor.tensor.rewriting.math import local_lift_transpose_through_dot
6971
from pytensor.tensor.rewriting.shape import ShapeFeature
@@ -1992,3 +1994,19 @@ def test_shape_unsafe_tag():
19921994
fn = function([x, y], out, mode=mode.excluding("shape_unsafe"))
19931995
with pytest.raises(ValueError):
19941996
fn([0, 1], [2, 3, 4]), [0, 1]
1997+
1998+
1999+
def test_topological_fill_sink_multi_output_client():
2000+
x = float64("x")
2001+
elem_op_with_2_outputs = Elemwise(Composite([x], [x + 1, x + 2]))
2002+
2003+
x = pt.vector("x", shape=(1,))
2004+
z = pt.vector("z", shape=(None,))
2005+
bcast_x = pt.full_like(z, x)
2006+
out = pt.add(*elem_op_with_2_outputs(pt.exp(bcast_x)))
2007+
2008+
fg = FunctionGraph([x, z], [out], copy_inputs=False)
2009+
topological_fill_sink.rewrite(fg)
2010+
[new_out] = fg.outputs
2011+
expected_out = pt.full_like(z, pt.add(*elem_op_with_2_outputs(pt.exp(x))))
2012+
assert equal_computations([new_out], [expected_out])

0 commit comments

Comments
 (0)