@@ -936,6 +936,56 @@ def graph_inputs(
936
936
yield from (r for r in ancestors (graphs , blockers ) if r .owner is None )
937
937
938
938
939
+ def explicit_graph_inputs (
940
+ graph : Variable [Any , Any ] | Iterable [Variable [Any , Any ]],
941
+ ) -> Generator [Variable , None , None ]:
942
+ """
943
+ Get the root variables needed as inputs to a function that computes `graph`
944
+
945
+ Parameters
946
+ ----------
947
+ graph : TensorVariable
948
+ Output `Variable` instances for which to search backward through
949
+ owners.
950
+
951
+ Returns
952
+ -------
953
+ iterable
954
+ Generator of root Variables (without owner) needed to compile a function that evaluates `graphs`.
955
+
956
+ Examples
957
+ --------
958
+
959
+ .. code-block:: python
960
+
961
+ import pytensor
962
+ import numpy as np
963
+ import pytensor.tensor as pt
964
+ from pytensor.graph.basic import explicit_graph_inputs
965
+
966
+ x = pt.vector('x')
967
+ y = pt.constant(2)
968
+ z = pt.mul(x*y)
969
+
970
+ out = list(explicit_graph_inputs(z))
971
+ f = pytensor.function([x], out)
972
+ eval = f(np.array([1, 2, 3]))
973
+
974
+ print(eval)
975
+ # [array([1., 2., 3.])]
976
+ """
977
+ from pytensor .compile .sharedvalue import SharedVariable
978
+
979
+ if isinstance (graph , Variable ):
980
+ graph = [graph ]
981
+
982
+ return (
983
+ v
984
+ for v in graph_inputs (graph )
985
+ if isinstance (v , Variable ) and not isinstance (v , Constant | SharedVariable )
986
+ )
987
+
988
+
939
989
def vars_between (
940
990
ins : Collection [Variable ], outs : Iterable [Variable ]
941
991
) -> Generator [Variable , None , None ]:
0 commit comments