Skip to content

Commit e180927

Browse files
committed
Fix bug where ShapeFeature would create circular shape graph
1 parent f951743 commit e180927

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

pytensor/tensor/rewriting/shape.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,17 @@ def update_shape(self, r, other_r):
423423
# This mean the shape is equivalent
424424
# We do not want to do the ancestor check in those cases
425425
merged_shape.append(r_shape[i])
426-
elif r_shape[i] in ancestors([other_shape[i]]):
426+
elif any(
427+
(
428+
r_shape[i] == anc
429+
or (
430+
anc.owner
431+
and isinstance(anc.owner.op, Shape)
432+
and anc.owner.inputs[0] == r
433+
)
434+
)
435+
for anc in ancestors([other_shape[i]])
436+
):
427437
# Another case where we want to use r_shape[i] is when
428438
# other_shape[i] actually depends on r_shape[i]. In that case,
429439
# we do not want to substitute an expression with another that

tests/tensor/rewriting/test_shape.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pytensor.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in
1616
from pytensor.graph.rewriting.utils import rewrite_graph
1717
from pytensor.graph.type import Type
18-
from pytensor.tensor.basic import as_tensor_variable
18+
from pytensor.tensor.basic import alloc, as_tensor_variable
1919
from pytensor.tensor.elemwise import DimShuffle, Elemwise
2020
from pytensor.tensor.math import add, exp, maximum
2121
from pytensor.tensor.rewriting.basic import register_specialize
@@ -239,6 +239,25 @@ def test_no_shapeopt(self):
239239
# FIXME: This is not a good test.
240240
f([[1, 2], [2, 3]])
241241

242+
def test_shape_of_useless_alloc(self):
243+
"""Test that local_shape_to_shape_i does not create circular graph.
244+
245+
Regression test for #565
246+
"""
247+
alpha = vector(shape=(None,), dtype="float64")
248+
channel = vector(shape=(None,), dtype="float64")
249+
250+
broadcast_channel = alloc(
251+
channel,
252+
maximum(
253+
shape(alpha)[0],
254+
shape(channel)[0],
255+
),
256+
)
257+
out = shape(broadcast_channel)
258+
fn = function([alpha, channel], out)
259+
assert fn([1.0, 2, 3], [1.0, 2, 3]) == (3,)
260+
242261

243262
class TestReshape:
244263
def setup_method(self):

0 commit comments

Comments
 (0)