Skip to content

Commit 9f7ea3e

Browse files
Add type_hints for required_graph_inputs�
1 parent e19bb34 commit 9f7ea3e

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pytensor/graph/basic.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -908,7 +908,9 @@ def graph_inputs(
908908

909909

910910
# def required_graph_inputs(graph:Variable | Iterator[Variable])-> Generator[Variable, None, None]:
911-
def required_graph_inputs(graph: list[Variable[Any, Any]]) -> list[Variable[Any, Any]]:
911+
def required_graph_inputs(
912+
graph: Variable[Any, Any] | Iterable[Variable[Any, Any]],
913+
) -> list[Variable[Any, Any]]:
912914
"""
913915
Get the inputs into PyTensor variables
914916
@@ -952,8 +954,8 @@ def required_graph_inputs(graph: list[Variable[Any, Any]]) -> list[Variable[Any,
952954
"""
953955
from pytensor.compile.sharedvalue import SharedVariable
954956

955-
# if isinstance(graph, Variable):
956-
# graph = [graph]
957+
if isinstance(graph, Variable):
958+
graph = [graph]
957959

958960
return [
959961
v

0 commit comments

Comments
 (0)