Skip to content

Commit efa088a

Browse files
Improve docstring for explicit_graph_inputs
1 parent d34760d commit efa088a

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

pytensor/graph/basic.py

+50
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,56 @@ def graph_inputs(
936936
yield from (r for r in ancestors(graphs, blockers) if r.owner is None)
937937

938938

939+
def explicit_graph_inputs(
940+
graph: Variable[Any, Any] | Iterable[Variable[Any, Any]],
941+
) -> Generator[Variable, None, None]:
942+
"""
943+
Get the root variables needed as inputs to a function that computes `graph`
944+
945+
Parameters
946+
----------
947+
graph : TensorVariable
948+
Output `Variable` instances for which to search backward through
949+
owners.
950+
951+
Returns
952+
-------
953+
iterable
954+
Generator of root Variables (without owner) needed to compile a function that evaluates `graphs`.
955+
956+
Examples
957+
--------
958+
959+
.. code-block:: python
960+
961+
import pytensor
962+
import numpy as np
963+
import pytensor.tensor as pt
964+
from pytensor.graph.basic import explicit_graph_inputs
965+
966+
x = pt.vector('x')
967+
y = pt.constant(2)
968+
z = pt.mul(x*y)
969+
970+
out = list(explicit_graph_inputs(z))
971+
f = pytensor.function([x], out)
972+
eval = f(np.array([1, 2, 3]))
973+
974+
print(eval)
975+
# [array([1., 2., 3.])]
976+
"""
977+
from pytensor.compile.sharedvalue import SharedVariable
978+
979+
if isinstance(graph, Variable):
980+
graph = [graph]
981+
982+
return (
983+
v
984+
for v in graph_inputs(graph)
985+
if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable)
986+
)
987+
988+
939989
def vars_between(
940990
ins: Collection[Variable], outs: Iterable[Variable]
941991
) -> Generator[Variable, None, None]:

tests/graph/test_basic.py

+14
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
clone,
1919
clone_get_equiv,
2020
equal_computations,
21+
explicit_graph_inputs,
2122
general_toposort,
2223
get_var_by_name,
2324
graph_inputs,
@@ -522,6 +523,19 @@ def test_graph_inputs():
522523
assert res_list == [r3, r1, r2]
523524

524525

526+
def test_explicit_graph_inputs():
527+
x = pt.fscalar()
528+
y = pt.constant(2)
529+
z = shared(1)
530+
a = pt.sum(x + y + z)
531+
b = pt.true_div(x, y)
532+
533+
res = list(explicit_graph_inputs([a]))
534+
res1 = list(explicit_graph_inputs(b))
535+
536+
assert res, res1 == [x]
537+
538+
525539
def test_variables_and_orphans():
526540
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
527541
o1 = MyOp(r1, r2)

0 commit comments

Comments
 (0)