8
8
from pytensor import tensor as pt
9
9
from pytensor .graph .basic import (
10
10
Apply ,
11
+ Constant ,
11
12
NominalVariable ,
12
13
Variable ,
13
14
ancestors ,
23
24
io_toposort ,
24
25
list_of_nodes ,
25
26
orphans_between ,
27
+ required_graph_inputs ,
26
28
truncated_graph_inputs ,
27
29
variable_depends_on ,
28
30
vars_between ,
@@ -61,6 +63,16 @@ def MyVariable(thingy):
61
63
return Variable (MyType (thingy ), None , None )
62
64
63
65
66
+ def MyConstant (name , thingy , data = None ):
67
+ return Constant (MyType (thingy ), data , name = name )
68
+
69
+
70
+ # def MyConstant(thingy):
71
+ # return Constant(MyType(thingy), None, None)
72
+ # def MyVariable(thingy):
73
+ # return Variable(MyType(thingy), None, None)
74
+
75
+
64
76
class MyOp (Op ):
65
77
__props__ = ()
66
78
@@ -75,6 +87,20 @@ def perform(self, *args, **kwargs):
75
87
raise NotImplementedError ("No Python implementation available." )
76
88
77
89
90
+ class MyCustomOp (Op ):
91
+ __props__ = ()
92
+
93
+ def make_node (self , * inputs ):
94
+ for input in inputs :
95
+ assert isinstance (input , Variable )
96
+ assert isinstance (input .type , MyType )
97
+ outputs = [MyVariable (sum (input .type .thingy for input in inputs ))]
98
+ return Apply (self , list (inputs ), outputs )
99
+
100
+ def perform (self , * args , ** kwargs ):
101
+ raise NotImplementedError ("No Python implementation available." )
102
+
103
+
78
104
MyOp = MyOp ()
79
105
80
106
@@ -501,6 +527,16 @@ def test_graph_inputs():
501
527
assert res_list == [r3 , r1 , r2 ]
502
528
503
529
530
+ def test_required_graph_inputs ():
531
+ x = pt .fscalar ()
532
+ y = pt .constant (2 )
533
+ z = shared (1 )
534
+ a = pt .sum (x + y + z )
535
+
536
+ res = list (required_graph_inputs ([a ]))
537
+ assert res == [x ]
538
+
539
+
504
540
def test_variables_and_orphans ():
505
541
r1 , r2 , r3 = MyVariable (1 ), MyVariable (2 ), MyVariable (3 )
506
542
o1 = MyOp (r1 , r2 )
0 commit comments