Skip to content

Commit b339ca5

Browse files
michaelosthegetwiecki
authored andcommitted
Add initial point caching and recompute_initial_point method
Prepares the API signature for lazy initval evaluation.
1 parent c4da4e3 commit b339ca5

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

pymc3/model.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,7 @@ def __init__(
650650
# The sequence of model-generated RNGs
651651
self.rng_seq = []
652652
self._initial_values = {}
653+
self._initial_point_cache = {}
653654

654655
if self.parent is not None:
655656
self.named_vars = treedict(parent=self.parent.named_vars)
@@ -926,15 +927,28 @@ def cont_vars(self):
926927
def test_point(self) -> Dict[str, np.ndarray]:
927928
"""Deprecated alias for `Model.initial_point`."""
928929
warnings.warn(
929-
"`Model.test_point` has been deprecated. Use `Model.initial_point` instead.",
930+
"`Model.test_point` has been deprecated. Use `Model.initial_point` or `Model.recompute_initial_point()`.",
930931
DeprecationWarning,
931932
)
932933
return self.initial_point
933934

934935
@property
935936
def initial_point(self) -> Dict[str, np.ndarray]:
936-
"""Maps names of variables to initial values."""
937-
return Point(list(self.initial_values.items()), model=self)
937+
"""Maps free variable names to transformed, numeric initial values."""
938+
if set(self._initial_point_cache) != {get_var_name(k) for k in self.initial_values}:
939+
return self.recompute_initial_point()
940+
return self._initial_point_cache
941+
942+
def recompute_initial_point(self) -> Dict[str, np.ndarray]:
943+
"""Recomputes numeric initial values for all free model variables.
944+
945+
Returns
946+
-------
947+
initial_point : dict
948+
Maps free variable names to transformed, numeric initial values.
949+
"""
950+
self._initial_point_cache = Point(list(self.initial_values.items()), model=self)
951+
return self._initial_point_cache
938952

939953
@property
940954
def initial_values(self) -> Dict[TensorVariable, np.ndarray]:

0 commit comments

Comments
 (0)