Skip to content

Commit ceb891d

Browse files
Add helper to extract variable graph_input
1 parent 5e612ab commit ceb891d

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

pytensor/graph/basic.py

+29
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,35 @@ def graph_inputs(
907907
yield from (r for r in ancestors(graphs, blockers) if r.owner is None)
908908

909909

910+
def required_graph_inputs(graph):
911+
"""
912+
Get the inputs into PyTensor variables
913+
914+
Parameters
915+
----------
916+
graph: PyTensor `Variable` instances
917+
Output `Variable` instances from which to search backward through
918+
owners.
919+
920+
Returns
921+
-------
922+
List of tensor variables that are input nodes with no owner, in the order
923+
found by a left-recursive depth-first search started at the nodes in `graphs`.
924+
"""
925+
from pytensor.compile.sharedvalue import SharedVariable
926+
927+
if isinstance(graph, tuple | list):
928+
graph = graph
929+
else:
930+
graph = [graph]
931+
932+
return [
933+
v
934+
for v in graph_inputs(graph)
935+
if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable)
936+
]
937+
938+
910939
def vars_between(
911940
ins: Collection[Variable], outs: Iterable[Variable]
912941
) -> Generator[Variable, None, None]:

tests/graph/test_basic.py

+36
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pytensor import tensor as pt
99
from pytensor.graph.basic import (
1010
Apply,
11+
Constant,
1112
NominalVariable,
1213
Variable,
1314
ancestors,
@@ -23,6 +24,7 @@
2324
io_toposort,
2425
list_of_nodes,
2526
orphans_between,
27+
required_graph_inputs,
2628
truncated_graph_inputs,
2729
variable_depends_on,
2830
vars_between,
@@ -61,6 +63,16 @@ def MyVariable(thingy):
6163
return Variable(MyType(thingy), None, None)
6264

6365

66+
def MyConstant(name, thingy, data=None):
67+
return Constant(MyType(thingy), data, name=name)
68+
69+
70+
# def MyConstant(thingy):
71+
# return Constant(MyType(thingy), None, None)
72+
# def MyVariable(thingy):
73+
# return Variable(MyType(thingy), None, None)
74+
75+
6476
class MyOp(Op):
6577
__props__ = ()
6678

@@ -75,6 +87,20 @@ def perform(self, *args, **kwargs):
7587
raise NotImplementedError("No Python implementation available.")
7688

7789

90+
class MyCustomOp(Op):
91+
__props__ = ()
92+
93+
def make_node(self, *inputs):
94+
for input in inputs:
95+
assert isinstance(input, Variable)
96+
assert isinstance(input.type, MyType)
97+
outputs = [MyVariable(sum(input.type.thingy for input in inputs))]
98+
return Apply(self, list(inputs), outputs)
99+
100+
def perform(self, *args, **kwargs):
101+
raise NotImplementedError("No Python implementation available.")
102+
103+
78104
MyOp = MyOp()
79105

80106

@@ -501,6 +527,16 @@ def test_graph_inputs():
501527
assert res_list == [r3, r1, r2]
502528

503529

530+
def test_required_graph_inputs():
531+
x = pt.fscalar()
532+
y = pt.constant(2)
533+
z = shared(1)
534+
a = pt.sum(x + y + z)
535+
536+
res = list(required_graph_inputs([a]))
537+
assert res == [x]
538+
539+
504540
def test_variables_and_orphans():
505541
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
506542
o1 = MyOp(r1, r2)

0 commit comments

Comments
 (0)