Skip to content

Commit e19bb34

Browse files
store changes for graph
1 parent 13e6f87 commit e19bb34

File tree

2 files changed

+31
-30
lines changed

2 files changed

+31
-30
lines changed

pytensor/graph/basic.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,8 @@ def graph_inputs(
907907
yield from (r for r in ancestors(graphs, blockers) if r.owner is None)
908908

909909

910-
def required_graph_inputs(graph):
910+
# def required_graph_inputs(graph:Variable | Iterator[Variable])-> Generator[Variable, None, None]:
911+
def required_graph_inputs(graph: list[Variable[Any, Any]]) -> list[Variable[Any, Any]]:
911912
"""
912913
Get the inputs into PyTensor variables
913914
@@ -921,13 +922,38 @@ def required_graph_inputs(graph):
921922
-------
922923
List of tensor variables that are input nodes with no owner, in the order
923924
found by a left-recursive depth-first search started at the nodes in `graphs`.
925+
926+
Examples
927+
--------
928+
>>> import pytensor as pt
929+
>>> x=pt.vector('x')
930+
>>> y=pt.constant('y')
931+
>>> z = pt.mul(x*y)
932+
>>> required_graph_inputs([a])
933+
[[[ 0 1 2 3]
934+
[ 4 5 6 7]
935+
[ 8 9 10 11]]
936+
937+
[[12 13 14 15]
938+
[16 17 18 19]
939+
[20 21 22 23]]]
940+
941+
942+
>>> pt.matrix_transpose(x).eval()
943+
[[[ 0 4 8]
944+
[ 1 5 9]
945+
[ 2 6 10]
946+
[ 3 7 11]]
947+
948+
[[12 16 20]
949+
[13 17 21]
950+
[14 18 22]
951+
[15 19 23]]]
924952
"""
925953
from pytensor.compile.sharedvalue import SharedVariable
926954

927-
if isinstance(graph, tuple | list):
928-
graph = graph
929-
else:
930-
graph = [graph]
955+
# if isinstance(graph, Variable):
956+
# graph = [graph]
931957

932958
return [
933959
v

tests/graph/test_basic.py

-25
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from pytensor import tensor as pt
99
from pytensor.graph.basic import (
1010
Apply,
11-
Constant,
1211
NominalVariable,
1312
Variable,
1413
ancestors,
@@ -63,16 +62,6 @@ def MyVariable(thingy):
6362
return Variable(MyType(thingy), None, None)
6463

6564

66-
def MyConstant(name, thingy, data=None):
67-
return Constant(MyType(thingy), data, name=name)
68-
69-
70-
# def MyConstant(thingy):
71-
# return Constant(MyType(thingy), None, None)
72-
# def MyVariable(thingy):
73-
# return Variable(MyType(thingy), None, None)
74-
75-
7665
class MyOp(Op):
7766
__props__ = ()
7867

@@ -87,20 +76,6 @@ def perform(self, *args, **kwargs):
8776
raise NotImplementedError("No Python implementation available.")
8877

8978

90-
class MyCustomOp(Op):
91-
__props__ = ()
92-
93-
def make_node(self, *inputs):
94-
for input in inputs:
95-
assert isinstance(input, Variable)
96-
assert isinstance(input.type, MyType)
97-
outputs = [MyVariable(sum(input.type.thingy for input in inputs))]
98-
return Apply(self, list(inputs), outputs)
99-
100-
def perform(self, *args, **kwargs):
101-
raise NotImplementedError("No Python implementation available.")
102-
103-
10479
MyOp = MyOp()
10580

10681

0 commit comments

Comments
 (0)