Skip to content

Commit 5f2ee98

Browse files
Stop caching initial points and wrap function creation
1 parent e761327 commit 5f2ee98

File tree

1 file changed

+44
-27
lines changed

1 file changed

+44
-27
lines changed

pymc3/model.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import (
2323
TYPE_CHECKING,
2424
Any,
25+
Callable,
2526
Dict,
2627
List,
2728
Optional,
@@ -650,7 +651,6 @@ def __init__(
650651
# The sequence of model-generated RNGs
651652
self.rng_seq = []
652653
self._initial_values = {}
653-
self._initial_point_cache = {}
654654

655655
if self.parent is not None:
656656
self.named_vars = treedict(parent=self.parent.named_vars)
@@ -935,42 +935,59 @@ def test_point(self) -> Dict[str, np.ndarray]:
935935
@property
936936
def initial_point(self) -> Dict[str, np.ndarray]:
937937
"""Maps free variable names to transformed, numeric initial values."""
938-
if set(self._initial_point_cache) != {
939-
get_var_name(self.rvs_to_values[k]) for k in self.initial_values
940-
}:
941-
return self.recompute_initial_point()
942-
return self._initial_point_cache
938+
return self.recompute_initial_point()
943939

944940
def recompute_initial_point(self) -> Dict[str, np.ndarray]:
941+
"""Recomputes the initial point of the model.
942+
943+
Returns
944+
-------
945+
ip : dict
946+
Maps names of transformed variables to numeric initial values in the transformed space.
947+
"""
948+
fn = self.make_initial_point_fn()
949+
return Point(fn(), model=self)
950+
951+
def make_initial_point_fn(
952+
self,
953+
*,
954+
return_transformed: bool = True,
955+
) -> Callable[[], Dict[TensorVariable, np.ndarray]]:
945956
"""Recomputes numeric initial values for all free model variables.
946957
958+
Parameters
959+
----------
960+
return_transformed : bool
961+
Switches between returning the dictionary based on RV vars or RV value vars as keys.
962+
947963
Returns
948964
-------
949965
initial_point : dict
950966
Maps transformed free variable names to transformed, numeric initial values.
951967
"""
952-
numeric_initvals = {}
953-
# The entries in `initial_values` are already in topological order and can be evaluated one by one.
954-
for rv_var, initval in self.initial_values.items():
955-
rv_value = self.rvs_to_values[rv_var]
956-
transform = getattr(rv_value.tag, "transform", None)
957-
if isinstance(initval, np.ndarray) and transform is None:
958-
# Only untransformed, numeric initvals can be taken as they are.
959-
numeric_initvals[rv_var] = initval
960-
else:
961-
# Evaluate initvals that are None, symbolic or need to be transformed.
962-
# They can depend on other initvals from higher up in the graph,
963-
# which are therefore fed to the evaluation as "givens".
964-
test_value = getattr(rv_var.tag, "test_value", None)
965-
numeric_initvals[rv_var] = self._eval_initval(
966-
rv_var, initval, test_value, transform, given=numeric_initvals
967-
)
968968

969-
# Cache the evaluation results for next time.
970-
self._initial_point_cache = Point(
971-
[(self.rvs_to_values[k], v) for k, v in numeric_initvals.items()], model=self
972-
)
973-
return self._initial_point_cache
969+
def fn():
970+
numeric_initvals = {}
971+
# The entries in `initial_values` are already in topological order and can be evaluated one by one.
972+
for rv_var, initval in self.initial_values.items():
973+
rv_value = self.rvs_to_values[rv_var]
974+
transform = getattr(rv_value.tag, "transform", None)
975+
if isinstance(initval, np.ndarray) and transform is None:
976+
# Only untransformed, numeric initvals can be taken as they are.
977+
numeric_initvals[rv_var] = initval
978+
else:
979+
# Evaluate initvals that are None, symbolic or need to be transformed.
980+
# They can depend on other initvals from higher up in the graph,
981+
# which are therefore fed to the evaluation as "givens".
982+
test_value = getattr(rv_var.tag, "test_value", None)
983+
numeric_initvals[rv_var] = self._eval_initval(
984+
rv_var, initval, test_value, transform, given=numeric_initvals
985+
)
986+
if return_transformed:
987+
return {self.rvs_to_values[k]: v for k, v in numeric_initvals.items()}
988+
return numeric_initvals
989+
990+
return fn
974991

975992
@property
976993
def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]:

0 commit comments

Comments
 (0)