@@ -697,55 +697,95 @@ def test_variable_depends_on():
697
697
assert variable_depends_on (y , [y ])
698
698
699
699
700
- def test_truncated_graph_inputs ():
701
- """
702
- * No conditions
703
- n - n - (o)
704
-
705
- * One condition
706
- n - (c) - o
707
-
708
- * Two conditions where on depends on another, both returned
709
- (c) - (c) - o
710
-
711
- * Additional nodes are present
712
- (c) - n - o
713
- n - (n) -'
700
+ class TestTruncatedGraphInputs :
701
+ def test_basic (self ):
702
+ """
703
+ * No conditions
704
+ n - n - (o)
705
+
706
+ * One condition
707
+ n - (c) - o
708
+
709
+ * Two conditions where on depends on another, both returned
710
+ (c) - (c) - o
711
+
712
+ * Additional nodes are present
713
+ (c) - n - o
714
+ n - (n) -'
715
+
716
+ * Disconnected condition not returned
717
+ (c) - n - o
718
+ c
719
+
720
+ * Disconnected output is present and returned
721
+ (c) - (c) - o
722
+ (o)
723
+
724
+ * Condition on itself adds itself
725
+ n - (c) - (o/c)
726
+ """
727
+ x = MyVariable (1 )
728
+ x .name = "x"
729
+ y = MyVariable (1 )
730
+ y .name = "y"
731
+ z = MyVariable (1 )
732
+ z .name = "z"
733
+ x2 = MyOp (x )
734
+ x2 .name = "x2"
735
+ y2 = MyOp (y , x2 )
736
+ y2 .name = "y2"
737
+ o = MyOp (y2 )
738
+ o2 = MyOp (o )
739
+ # No conditions
740
+ assert truncated_graph_inputs ([o ]) == [o ]
741
+ # One condition
742
+ assert truncated_graph_inputs ([o2 ], [y2 ]) == [y2 ]
743
+ # Condition on itself adds itself
744
+ assert truncated_graph_inputs ([o ], [y2 , o ]) == [o , y2 ]
745
+ # Two conditions where on depends on another, both returned
746
+ assert truncated_graph_inputs ([o2 ], [y2 , o ]) == [o , y2 ]
747
+ # Additional nodes are present
748
+ assert truncated_graph_inputs ([o ], [y ]) == [x2 , y ]
749
+ # Disconnected condition
750
+ assert truncated_graph_inputs ([o2 ], [y2 , z ]) == [y2 ]
751
+ # Disconnected output is present
752
+ assert truncated_graph_inputs ([o2 , z ], [y2 ]) == [z , y2 ]
753
+
754
+ def test_repeated_input (self ):
755
+ """Test that truncated_graph_inputs does not return repeated inputs."""
756
+ x = MyVariable (1 )
757
+ x .name = "x"
758
+ y = MyVariable (1 )
759
+ y .name = "y"
760
+
761
+ trunc_inp1 = MyOp (x , y )
762
+ trunc_inp1 .name = "trunc_inp1"
763
+
764
+ trunc_inp2 = MyOp (x , y )
765
+ trunc_inp2 .name = "trunc_inp2"
766
+
767
+ o = MyOp (trunc_inp1 , trunc_inp1 , trunc_inp2 , trunc_inp2 )
768
+ o .name = "o"
769
+
770
+ assert truncated_graph_inputs ([o ], [trunc_inp1 ]) == [trunc_inp2 , trunc_inp1 ]
771
+
772
+ def test_repeated_nested_input (self ):
773
+ """Test that truncated_graph_inputs does not return repeated inputs."""
774
+ x = MyVariable (1 )
775
+ x .name = "x"
776
+ y = MyVariable (1 )
777
+ y .name = "y"
778
+
779
+ trunc_inp = MyOp (x , y )
780
+ trunc_inp .name = "trunc_inp"
781
+
782
+ o1 = MyOp (trunc_inp , trunc_inp , x , x )
783
+ o1 .name = "o1"
714
784
715
- * Disconnected condition not returned
716
- (c) - n - o
717
- c
785
+ assert truncated_graph_inputs ([o1 ], [trunc_inp ]) == [x , trunc_inp ]
718
786
719
- * Disconnected output is present and returned
720
- (c) - (c) - o
721
- (o)
787
+ # Reverse order of inputs
788
+ o2 = MyOp ( x , x , trunc_inp , trunc_inp )
789
+ o2 . name = "o2"
722
790
723
- * Condition on itself adds itself
724
- n - (c) - (o/c)
725
- """
726
- x = MyVariable (1 )
727
- x .name = "x"
728
- y = MyVariable (1 )
729
- y .name = "y"
730
- z = MyVariable (1 )
731
- z .name = "z"
732
- x2 = MyOp (x )
733
- x2 .name = "x2"
734
- y2 = MyOp (y , x2 )
735
- y2 .name = "y2"
736
- o = MyOp (y2 )
737
- o2 = MyOp (o )
738
- # No conditions
739
- assert truncated_graph_inputs ([o ]) == [o ]
740
- # One condition
741
- assert truncated_graph_inputs ([o2 ], [y2 ]) == [y2 ]
742
- # Condition on itself adds itself
743
- assert truncated_graph_inputs ([o ], [y2 , o ]) == [o , y2 ]
744
- # Two conditions where on depends on another, both returned
745
- assert truncated_graph_inputs ([o2 ], [y2 , o ]) == [o , y2 ]
746
- # Additional nodes are present
747
- assert truncated_graph_inputs ([o ], [y ]) == [x2 , y ]
748
- # Disconnected condition
749
- assert truncated_graph_inputs ([o2 ], [y2 , z ]) == [y2 ]
750
- # Disconnected output is present
751
- assert truncated_graph_inputs ([o2 , z ], [y2 ]) == [z , y2 ]
791
+ assert truncated_graph_inputs ([o2 ], [trunc_inp ]) == [trunc_inp , x ]
0 commit comments