@@ -935,7 +935,9 @@ def test_point(self) -> Dict[str, np.ndarray]:
935
935
@property
936
936
def initial_point (self ) -> Dict [str , np .ndarray ]:
937
937
"""Maps free variable names to transformed, numeric initial values."""
938
- if set (self ._initial_point_cache ) != {get_var_name (k ) for k in self .initial_values }:
938
+ if set (self ._initial_point_cache ) != {
939
+ get_var_name (self .rvs_to_values [k ]) for k in self .initial_values
940
+ }:
939
941
return self .recompute_initial_point ()
940
942
return self ._initial_point_cache
941
943
@@ -949,40 +951,40 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]:
949
951
"""
950
952
numeric_initvals = {}
951
953
# The entries in `initial_values` are already in topological order and can be evaluated one by one.
952
- for rv_value , initval in self .initial_values .items ():
953
- rv_var = self .values_to_rvs [ rv_value ]
954
+ for rv_var , initval in self .initial_values .items ():
955
+ rv_value = self .rvs_to_values [ rv_var ]
954
956
transform = getattr (rv_value .tag , "transform" , None )
955
957
if isinstance (initval , np .ndarray ) and transform is None :
956
958
# Only untransformed, numeric initvals can be taken as they are.
957
- numeric_initvals [rv_value ] = initval
959
+ numeric_initvals [rv_var ] = initval
958
960
else :
959
961
# Evaluate initvals that are None, symbolic or need to be transformed.
960
962
# They can depend on other initvals from higher up in the graph,
961
963
# which are therefore fed to the evaluation as "givens".
962
964
test_value = getattr (rv_var .tag , "test_value" , None )
963
- numeric_initvals [rv_value ] = self ._eval_initval (
965
+ numeric_initvals [rv_var ] = self ._eval_initval (
964
966
rv_var , initval , test_value , transform , given = numeric_initvals
965
967
)
966
968
967
969
# Cache the evaluation results for next time.
968
- self ._initial_point_cache = Point (list (numeric_initvals .items ()), model = self )
970
+ self ._initial_point_cache = Point (
971
+ [(self .rvs_to_values [k ], v ) for k , v in numeric_initvals .items ()], model = self
972
+ )
969
973
return self ._initial_point_cache
970
974
971
975
@property
972
976
def initial_values (self ) -> Dict [TensorVariable , Optional [Union [np .ndarray , Variable ]]]:
973
977
"""Maps transformed variables to initial value placeholders.
974
978
975
- ⚠ The keys are NOT the objects returned by, `pm.Normal(...)`.
976
- For a name-based dictionary use the `get_initial_point()` method.
979
+ Keys are the random variables (as returned by e.g. ``pm.Uniform()``).
977
980
"""
978
981
return self ._initial_values
979
982
980
983
def set_initval (self , rv_var , initval ):
981
984
if initval is not None :
982
985
initval = rv_var .type .filter (initval )
983
986
984
- rv_value_var = self .rvs_to_values [rv_var ]
985
- self .initial_values [rv_value_var ] = initval
987
+ self .initial_values [rv_var ] = initval
986
988
987
989
def _eval_initval (
988
990
self ,
@@ -1031,16 +1033,16 @@ def _eval_initval(
1031
1033
value = rv_var
1032
1034
rv_var = at .as_tensor_variable (transform .forward (rv_var , value ))
1033
1035
1034
- def initval_to_rvval (value_var , value ):
1035
- rv_var = self .values_to_rvs [ value_var ]
1036
+ def initval_to_rvval (rv_var , value ):
1037
+ value_var = self .rvs_to_values [ rv_var ]
1036
1038
initval = value_var .type .make_constant (value )
1037
1039
transform = getattr (value_var .tag , "transform" , None )
1038
1040
if transform :
1039
1041
return transform .backward (rv_var , initval )
1040
1042
else :
1041
1043
return initval
1042
1044
1043
- givens = {self . values_to_rvs [ k ] : initval_to_rvval (k , v ) for k , v in given .items ()}
1045
+ givens = {k : initval_to_rvval (k , v ) for k , v in given .items ()}
1044
1046
initval_fn = aesara .function ([], rv_var , mode = mode , givens = givens , on_unused_input = "ignore" )
1045
1047
try :
1046
1048
initval = initval_fn ()
0 commit comments