Skip to content

Commit 2315e69

Browse files
authored
support on_unused_input for string parameter names in eval (#1085)
1 parent d9d8dba commit 2315e69

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

pytensor/graph/basic.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -616,16 +616,20 @@ def eval(
616616
"""
617617
from pytensor.compile.function import function
618618

619+
ignore_unused_input = kwargs.get("on_unused_input", None) in ("ignore", "warn")
620+
619621
def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]:
620622
new_input_to_values = {}
621623
for key, value in inputs_to_values.items():
622624
if isinstance(key, str):
623625
matching_vars = get_var_by_name([self], key)
624626
if not matching_vars:
625-
raise ValueError(f"{key} not found in graph")
627+
if not ignore_unused_input:
628+
raise ValueError(f"{key} not found in graph")
626629
elif len(matching_vars) > 1:
627630
raise ValueError(f"Found multiple variables with name {key}")
628-
new_input_to_values[matching_vars[0]] = value
631+
else:
632+
new_input_to_values[matching_vars[0]] = value
629633
else:
630634
new_input_to_values[key] = value
631635
return new_input_to_values

tests/graph/test_basic.py

+4
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,10 @@ def test_eval_kwargs(self):
367367
self.w.eval({self.z: 3, self.x: 2.5})
368368
assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0
369369

370+
# regression test for https://github.com/pymc-devs/pytensor/issues/1084
371+
q = self.x + 1
372+
assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0
373+
370374
@pytest.mark.filterwarnings("error")
371375
def test_eval_unashable_kwargs(self):
372376
y_repl = constant(2.0, dtype="floatX")

0 commit comments

Comments
 (0)