Skip to content

Commit 1117ea5

Browse files
committed
Allow keyword arguments in eval method
1 parent 98070db commit 1117ea5

File tree

2 files changed

+52
-15
lines changed

2 files changed

+52
-15
lines changed

pytensor/graph/basic.py

+32-15
Original file line numberDiff line numberDiff line change
@@ -555,13 +555,20 @@ def get_parents(self):
555555
return [self.owner]
556556
return []
557557

558-
def eval(self, inputs_to_values=None):
559-
r"""Evaluate the `Variable`.
558+
def eval(
559+
self,
560+
inputs_to_values: dict[Union["Variable", str], Any] | None = None,
561+
**kwargs,
562+
):
563+
r"""Evaluate the `Variable` given a set of values for its inputs.
560564
561565
Parameters
562566
----------
563567
inputs_to_values :
564-
A dictionary mapping PyTensor `Variable`\s to values.
568+
A dictionary mapping PyTensor `Variable`\s or names to values.
569+
Not needed if variable has no required inputs.
570+
kwargs :
571+
Optional keyword arguments to pass to the underlying `pytensor.function`
565572
566573
Examples
567574
--------
@@ -591,10 +598,7 @@ def eval(self, inputs_to_values=None):
591598
"""
592599
from pytensor.compile.function import function
593600

594-
if inputs_to_values is None:
595-
inputs_to_values = {}
596-
597-
def convert_string_keys_to_variables(input_to_values):
601+
def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]:
598602
new_input_to_values = {}
599603
for key, value in inputs_to_values.items():
600604
if isinstance(key, str):
@@ -608,19 +612,32 @@ def convert_string_keys_to_variables(input_to_values):
608612
new_input_to_values[key] = value
609613
return new_input_to_values
610614

611-
inputs_to_values = convert_string_keys_to_variables(inputs_to_values)
615+
parsed_inputs_to_values: dict[Variable, Any] = {}
616+
if inputs_to_values is not None:
617+
parsed_inputs_to_values = convert_string_keys_to_variables(inputs_to_values)
612618

613619
if not hasattr(self, "_fn_cache"):
614-
self._fn_cache = dict()
620+
self._fn_cache: dict = dict()
615621

616-
inputs = tuple(sorted(inputs_to_values.keys(), key=id))
617-
if inputs not in self._fn_cache:
618-
self._fn_cache[inputs] = function(inputs, self)
619-
args = [inputs_to_values[param] for param in inputs]
622+
inputs = tuple(sorted(parsed_inputs_to_values.keys(), key=id))
623+
cache_key = (inputs, tuple(kwargs.items()))
624+
try:
625+
fn = self._fn_cache[cache_key]
626+
except (KeyError, TypeError):
627+
fn = None
620628

621-
rval = self._fn_cache[inputs](*args)
629+
if fn is None:
630+
fn = function(inputs, self, **kwargs)
631+
try:
632+
self._fn_cache[cache_key] = fn
633+
except TypeError as exc:
634+
warnings.warn(
635+
"Keyword arguments could not be used to create a cache key for the underlying variable. "
636+
f"A function will be recompiled on every call with such keyword arguments.\n{exc}"
637+
)
622638

623-
return rval
639+
args = [parsed_inputs_to_values[param] for param in inputs]
640+
return fn(*args)
624641

625642
def __getstate__(self):
626643
d = self.__dict__.copy()

tests/graph/test_basic.py

+20
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from pytensor import shared
88
from pytensor import tensor as pt
9+
from pytensor.compile import UnusedInputError
910
from pytensor.graph.basic import (
1011
Apply,
1112
NominalVariable,
@@ -30,6 +31,7 @@
3031
)
3132
from pytensor.graph.op import Op
3233
from pytensor.graph.type import Type
34+
from pytensor.tensor import constant
3335
from pytensor.tensor.math import max_and_argmax
3436
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
3537
from pytensor.tensor.type_other import NoneConst
@@ -359,6 +361,24 @@ def test_eval_with_strings_no_match(self):
359361
with pytest.raises(Exception, match="o not found in graph"):
360362
t.eval({"o": 1})
361363

364+
def test_eval_kwargs(self):
365+
with pytest.raises(UnusedInputError):
366+
self.w.eval({self.z: 3, self.x: 2.5})
367+
assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0
368+
369+
@pytest.mark.filterwarnings("error")
370+
def test_eval_unashable_kwargs(self):
371+
y_repl = constant(2.0, dtype="floatX")
372+
373+
assert self.w.eval({self.x: 1.0}, givens=((self.y, y_repl),)) == 6.0
374+
375+
with pytest.warns(
376+
UserWarning,
377+
match="Keyword arguments could not be used to create a cache key",
378+
):
379+
# givens dict is not hashable
380+
assert self.w.eval({self.x: 1.0}, givens={self.y: y_repl}) == 6.0
381+
362382

363383
class TestAutoName:
364384
def test_auto_name(self):

0 commit comments

Comments
 (0)