Skip to content

Commit 21d723b

Browse files
committed
Fix bug in truncated_graph_inputs
It could return duplicated truncated inputs before the changes, as well as return wrong outputs based on the nodes input order
1 parent 12ca8bd commit 21d723b

File tree

2 files changed

+105
-64
lines changed

2 files changed

+105
-64
lines changed

pytensor/graph/basic.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,47 +1056,48 @@ def truncated_graph_inputs(
10561056
truncated_inputs.append(node)
10571057
# no more actions are needed
10581058
return truncated_inputs
1059+
10591060
blockers: Set[Variable] = set(ancestors_to_include)
10601061
# enforce O(1) check for node in ancestors to include
10611062
ancestors_to_include = blockers.copy()
10621063

10631064
while candidates:
10641065
# on any new candidate
10651066
node = candidates.pop()
1066-
# check if the node is independent, never go above blockers
1067+
1068+
# There was a repeated reference to this node, we have already investigated it
1069+
if node in truncated_inputs:
1070+
continue
1071+
1072+
# check if the node is independent, never go above blockers;
10671073
# blockers are independent nodes and ancestors to include
10681074
if node in ancestors_to_include:
10691075
# The case where node is in ancestors to include so we check if it depends on others
10701076
# it should be removed from the blockers to check against the rest
1071-
dependent = variable_depends_on(node, blockers - {node})
1077+
dependent = variable_depends_on(node, ancestors_to_include - {node})
10721078
# ancestors to include that are present in the graph (not disconnected)
10731079
# should be added to truncated_inputs
10741080
truncated_inputs.append(node)
10751081
if dependent:
1076-
# if the ancestors to include is still dependent we need to go above,
1077-
# the search is not yet finished
1078-
# the node _has_ to have owner to be dependent
1079-
# so we do not check it
1080-
# and populate search to go above
1082+
# if the ancestors to include is still dependent we need to go above, the search is not yet finished
10811083
# owner can never be None for a dependent node
10821084
candidates.extend(node.owner.inputs)
10831085
else:
10841086
# A regular node to check
10851087
dependent = variable_depends_on(node, blockers)
1086-
# all regular nodes fall to blockes
1088+
# all regular nodes fall to blockers
10871089
# 1. it is dependent - further search irrelevant
10881090
# 2. it is independent - the search node is inside the closure
10891091
blockers.add(node)
10901092
# if we've found an independent node and it is not in blockers so far
1091-
# it is a new indepenent node not present in ancestors to include
1092-
if not dependent:
1093-
# we've found an independent node
1094-
# do not search beyond
1095-
truncated_inputs.append(node)
1096-
else:
1097-
# populate search otherwise
1093+
# it is a new independent node not present in ancestors to include
1094+
if dependent:
1095+
# populate search if it's not an independent node
10981096
# owner can never be None for a dependent node
10991097
candidates.extend(node.owner.inputs)
1098+
else:
1099+
# otherwise, do not search beyond
1100+
truncated_inputs.append(node)
11001101
return truncated_inputs
11011102

11021103

tests/graph/test_basic.py

Lines changed: 89 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -697,55 +697,95 @@ def test_variable_depends_on():
697697
assert variable_depends_on(y, [y])
698698

699699

700-
def test_truncated_graph_inputs():
701-
"""
702-
* No conditions
703-
n - n - (o)
704-
705-
* One condition
706-
n - (c) - o
707-
708-
* Two conditions where on depends on another, both returned
709-
(c) - (c) - o
710-
711-
* Additional nodes are present
712-
(c) - n - o
713-
n - (n) -'
700+
class TestTruncatedGraphInputs:
701+
def test_basic(self):
702+
"""
703+
* No conditions
704+
n - n - (o)
705+
706+
* One condition
707+
n - (c) - o
708+
709+
* Two conditions where on depends on another, both returned
710+
(c) - (c) - o
711+
712+
* Additional nodes are present
713+
(c) - n - o
714+
n - (n) -'
715+
716+
* Disconnected condition not returned
717+
(c) - n - o
718+
c
719+
720+
* Disconnected output is present and returned
721+
(c) - (c) - o
722+
(o)
723+
724+
* Condition on itself adds itself
725+
n - (c) - (o/c)
726+
"""
727+
x = MyVariable(1)
728+
x.name = "x"
729+
y = MyVariable(1)
730+
y.name = "y"
731+
z = MyVariable(1)
732+
z.name = "z"
733+
x2 = MyOp(x)
734+
x2.name = "x2"
735+
y2 = MyOp(y, x2)
736+
y2.name = "y2"
737+
o = MyOp(y2)
738+
o2 = MyOp(o)
739+
# No conditions
740+
assert truncated_graph_inputs([o]) == [o]
741+
# One condition
742+
assert truncated_graph_inputs([o2], [y2]) == [y2]
743+
# Condition on itself adds itself
744+
assert truncated_graph_inputs([o], [y2, o]) == [o, y2]
745+
# Two conditions where on depends on another, both returned
746+
assert truncated_graph_inputs([o2], [y2, o]) == [o, y2]
747+
# Additional nodes are present
748+
assert truncated_graph_inputs([o], [y]) == [x2, y]
749+
# Disconnected condition
750+
assert truncated_graph_inputs([o2], [y2, z]) == [y2]
751+
# Disconnected output is present
752+
assert truncated_graph_inputs([o2, z], [y2]) == [z, y2]
753+
754+
def test_repeated_input(self):
755+
"""Test that truncated_graph_inputs does not return repeated inputs."""
756+
x = MyVariable(1)
757+
x.name = "x"
758+
y = MyVariable(1)
759+
y.name = "y"
760+
761+
trunc_inp1 = MyOp(x, y)
762+
trunc_inp1.name = "trunc_inp1"
763+
764+
trunc_inp2 = MyOp(x, y)
765+
trunc_inp2.name = "trunc_inp2"
766+
767+
o = MyOp(trunc_inp1, trunc_inp1, trunc_inp2, trunc_inp2)
768+
o.name = "o"
769+
770+
assert truncated_graph_inputs([o], [trunc_inp1]) == [trunc_inp2, trunc_inp1]
771+
772+
def test_repeated_nested_input(self):
773+
"""Test that truncated_graph_inputs does not return repeated inputs."""
774+
x = MyVariable(1)
775+
x.name = "x"
776+
y = MyVariable(1)
777+
y.name = "y"
778+
779+
trunc_inp = MyOp(x, y)
780+
trunc_inp.name = "trunc_inp"
781+
782+
o1 = MyOp(trunc_inp, trunc_inp, x, x)
783+
o1.name = "o1"
714784

715-
* Disconnected condition not returned
716-
(c) - n - o
717-
c
785+
assert truncated_graph_inputs([o1], [trunc_inp]) == [x, trunc_inp]
718786

719-
* Disconnected output is present and returned
720-
(c) - (c) - o
721-
(o)
787+
# Reverse order of inputs
788+
o2 = MyOp(x, x, trunc_inp, trunc_inp)
789+
o2.name = "o2"
722790

723-
* Condition on itself adds itself
724-
n - (c) - (o/c)
725-
"""
726-
x = MyVariable(1)
727-
x.name = "x"
728-
y = MyVariable(1)
729-
y.name = "y"
730-
z = MyVariable(1)
731-
z.name = "z"
732-
x2 = MyOp(x)
733-
x2.name = "x2"
734-
y2 = MyOp(y, x2)
735-
y2.name = "y2"
736-
o = MyOp(y2)
737-
o2 = MyOp(o)
738-
# No conditions
739-
assert truncated_graph_inputs([o]) == [o]
740-
# One condition
741-
assert truncated_graph_inputs([o2], [y2]) == [y2]
742-
# Condition on itself adds itself
743-
assert truncated_graph_inputs([o], [y2, o]) == [o, y2]
744-
# Two conditions where on depends on another, both returned
745-
assert truncated_graph_inputs([o2], [y2, o]) == [o, y2]
746-
# Additional nodes are present
747-
assert truncated_graph_inputs([o], [y]) == [x2, y]
748-
# Disconnected condition
749-
assert truncated_graph_inputs([o2], [y2, z]) == [y2]
750-
# Disconnected output is present
751-
assert truncated_graph_inputs([o2, z], [y2]) == [z, y2]
791+
assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x]

0 commit comments

Comments
 (0)