@@ -650,6 +650,7 @@ def __init__(
650
650
# The sequence of model-generated RNGs
651
651
self .rng_seq = []
652
652
self ._initial_values = {}
653
+ self ._initial_point_cache = {}
653
654
654
655
if self .parent is not None :
655
656
self .named_vars = treedict (parent = self .parent .named_vars )
@@ -926,15 +927,28 @@ def cont_vars(self):
926
927
def test_point (self ) -> Dict [str , np .ndarray ]:
927
928
"""Deprecated alias for `Model.initial_point`."""
928
929
warnings .warn (
929
- "`Model.test_point` has been deprecated. Use `Model.initial_point` instead ." ,
930
+ "`Model.test_point` has been deprecated. Use `Model.initial_point` or `Model.recompute_initial_point()` ." ,
930
931
DeprecationWarning ,
931
932
)
932
933
return self .initial_point
933
934
934
935
@property
935
936
def initial_point (self ) -> Dict [str , np .ndarray ]:
936
- """Maps names of variables to initial values."""
937
- return Point (list (self .initial_values .items ()), model = self )
937
+ """Maps free variable names to transformed, numeric initial values."""
938
+ if set (self ._initial_point_cache ) != {get_var_name (k ) for k in self .initial_values }:
939
+ return self .recompute_initial_point ()
940
+ return self ._initial_point_cache
941
+
942
+ def recompute_initial_point (self ) -> Dict [str , np .ndarray ]:
943
+ """Recomputes numeric initial values for all free model variables.
944
+
945
+ Returns
946
+ -------
947
+ initial_point : dict
948
+ Maps free variable names to transformed, numeric initial values.
949
+ """
950
+ self ._initial_point_cache = Point (list (self .initial_values .items ()), model = self )
951
+ return self ._initial_point_cache
938
952
939
953
@property
940
954
def initial_values (self ) -> Dict [TensorVariable , np .ndarray ]:
0 commit comments