Skip to content

Commit 8634a55

Browse files
Improve docstring for explicit_graph_inputs
1 parent 94b3ac7 commit 8634a55

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

pytensor/graph/basic.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -940,37 +940,39 @@ def explicit_graph_inputs(
940940
graph: Variable[Any, Any] | Iterable[Variable[Any, Any]],
941941
) -> Generator[Variable, None, None]:
942942
"""
943-
Get the inputs into PyTensor variables
943+
Get the root variables needed as inputs to a function that computes `graph`
944944
945945
Parameters
946946
----------
947-
graph: TensorVariable
948-
Output `Variable` instances from which to search backward through
947+
graph : TensorVariable
948+
Output `Variable` instances for which to search backward through
949949
owners.
950950
951951
Returns
952952
-------
953-
Tensor variables that are input nodes with no owner, in the order
954-
found by a left-recursive depth-first search started at the nodes in `graphs`.
953+
iterable
954+
Generator of root Variables (without owner) needed to compile a function that evaluates `graphs`.
955955
956956
Examples
957957
--------
958958
959959
.. code-block:: python
960960
961961
import pytensor
962+
import numpy as np
962963
import pytensor.tensor as pt
964+
from pytensor.graph.basic import explicit_graph_inputs
963965
964966
x = pt.vector('x')
965967
y = pt.constant(2)
966968
z = pt.mul(x*y)
967969
968-
pytensor.dprint(graph_inputs([z]))
969-
# x [id A]
970-
# 2 [id B]
970+
out = list(explicit_graph_inputs(z))
971+
f = pytensor.function([x], out)
972+
eval = f(np.array([1, 2, 3]))
971973
972-
pytensor.dprint(explicit_graph_inputs([z]))
973-
# x [id A]
974+
print(eval)
975+
# [array([1., 2., 3.])]
974976
"""
975977
from pytensor.compile.sharedvalue import SharedVariable
976978

tests/graph/test_basic.py

+1
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,7 @@ def test_explicit_graph_inputs():
532532

533533
res = list(explicit_graph_inputs([a]))
534534
res1 = list(explicit_graph_inputs(b))
535+
535536
assert res, res1 == [x]
536537

537538

0 commit comments

Comments
 (0)