@@ -555,13 +555,20 @@ def get_parents(self):
555
555
return [self .owner ]
556
556
return []
557
557
558
- def eval (self , inputs_to_values = None ):
559
- r"""Evaluate the `Variable`.
558
+ def eval (
559
+ self ,
560
+ inputs_to_values : dict [Union ["Variable" , str ], Any ] | None = None ,
561
+ ** kwargs ,
562
+ ):
563
+ r"""Evaluate the `Variable` given a set of values for its inputs.
560
564
561
565
Parameters
562
566
----------
563
567
inputs_to_values :
564
- A dictionary mapping PyTensor `Variable`\s to values.
568
+ A dictionary mapping PyTensor `Variable`\s or names to values.
569
+ Not needed if variable has no required inputs.
570
+ kwargs :
571
+ Optional keyword arguments to pass to the underlying `pytensor.function`
565
572
566
573
Examples
567
574
--------
@@ -591,10 +598,7 @@ def eval(self, inputs_to_values=None):
591
598
"""
592
599
from pytensor .compile .function import function
593
600
594
- if inputs_to_values is None :
595
- inputs_to_values = {}
596
-
597
- def convert_string_keys_to_variables (input_to_values ):
601
+ def convert_string_keys_to_variables (inputs_to_values ) -> dict ["Variable" , Any ]:
598
602
new_input_to_values = {}
599
603
for key , value in inputs_to_values .items ():
600
604
if isinstance (key , str ):
@@ -608,19 +612,32 @@ def convert_string_keys_to_variables(input_to_values):
608
612
new_input_to_values [key ] = value
609
613
return new_input_to_values
610
614
611
- inputs_to_values = convert_string_keys_to_variables (inputs_to_values )
615
+ parsed_inputs_to_values : dict [Variable , Any ] = {}
616
+ if inputs_to_values is not None :
617
+ parsed_inputs_to_values = convert_string_keys_to_variables (inputs_to_values )
612
618
613
619
if not hasattr (self , "_fn_cache" ):
614
- self ._fn_cache = dict ()
620
+ self ._fn_cache : dict = dict ()
615
621
616
- inputs = tuple (sorted (inputs_to_values .keys (), key = id ))
617
- if inputs not in self ._fn_cache :
618
- self ._fn_cache [inputs ] = function (inputs , self )
619
- args = [inputs_to_values [param ] for param in inputs ]
622
+ inputs = tuple (sorted (parsed_inputs_to_values .keys (), key = id ))
623
+ cache_key = (inputs , tuple (kwargs .items ()))
624
+ try :
625
+ fn = self ._fn_cache [cache_key ]
626
+ except (KeyError , TypeError ):
627
+ fn = None
620
628
621
- rval = self ._fn_cache [inputs ](* args )
629
+ if fn is None :
630
+ fn = function (inputs , self , ** kwargs )
631
+ try :
632
+ self ._fn_cache [cache_key ] = fn
633
+ except TypeError as exc :
634
+ warnings .warn (
635
+ "Keyword arguments could not be used to create a cache key for the underlying variable. "
636
+ f"A function will be recompiled on every call with such keyword arguments.\n { exc } "
637
+ )
622
638
623
- return rval
639
+ args = [parsed_inputs_to_values [param ] for param in inputs ]
640
+ return fn (* args )
624
641
625
642
def __getstate__ (self ):
626
643
d = self .__dict__ .copy ()
0 commit comments