|
22 | 22 | from typing import (
|
23 | 23 | TYPE_CHECKING,
|
24 | 24 | Any,
|
| 25 | + Callable, |
25 | 26 | Dict,
|
26 | 27 | List,
|
27 | 28 | Optional,
|
@@ -650,7 +651,6 @@ def __init__(
|
650 | 651 | # The sequence of model-generated RNGs
|
651 | 652 | self.rng_seq = []
|
652 | 653 | self._initial_values = {}
|
653 |
| - self._initial_point_cache = {} |
654 | 654 |
|
655 | 655 | if self.parent is not None:
|
656 | 656 | self.named_vars = treedict(parent=self.parent.named_vars)
|
@@ -935,42 +935,59 @@ def test_point(self) -> Dict[str, np.ndarray]:
|
935 | 935 | @property
|
936 | 936 | def initial_point(self) -> Dict[str, np.ndarray]:
|
937 | 937 | """Maps free variable names to transformed, numeric initial values."""
|
938 |
| - if set(self._initial_point_cache) != { |
939 |
| - get_var_name(self.rvs_to_values[k]) for k in self.initial_values |
940 |
| - }: |
941 |
| - return self.recompute_initial_point() |
942 |
| - return self._initial_point_cache |
| 938 | + return self.recompute_initial_point() |
943 | 939 |
|
944 | 940 | def recompute_initial_point(self) -> Dict[str, np.ndarray]:
|
| 941 | + """Recomputes the initial point of the model. |
| 942 | +
|
| 943 | + Returns |
| 944 | + ------- |
| 945 | + ip : dict |
| 946 | + Maps names of transformed variables to numeric initial values in the transformed space. |
| 947 | + """ |
| 948 | + fn = self.make_initial_point_fn() |
| 949 | + return Point(fn(), model=self) |
| 950 | + |
| 951 | + def make_initial_point_fn( |
| 952 | + self, |
| 953 | + *, |
| 954 | + return_transformed: bool = True, |
| 955 | + ) -> Callable[[], Dict[TensorVariable, np.ndarray]]: |
945 | 956 | """Recomputes numeric initial values for all free model variables.
|
946 | 957 |
|
| 958 | + Parameters |
| 959 | + ---------- |
| 960 | + return_transformed : bool |
| 961 | + Switches between returning the dictionary based on RV vars or RV value vars as keys. |
| 962 | +
|
947 | 963 | Returns
|
948 | 964 | -------
|
949 | 965 | initial_point : dict
|
950 | 966 | Maps transformed free variable names to transformed, numeric initial values.
|
951 | 967 | """
|
952 |
| - numeric_initvals = {} |
953 |
| - # The entries in `initial_values` are already in topological order and can be evaluated one by one. |
954 |
| - for rv_var, initval in self.initial_values.items(): |
955 |
| - rv_value = self.rvs_to_values[rv_var] |
956 |
| - transform = getattr(rv_value.tag, "transform", None) |
957 |
| - if isinstance(initval, np.ndarray) and transform is None: |
958 |
| - # Only untransformed, numeric initvals can be taken as they are. |
959 |
| - numeric_initvals[rv_var] = initval |
960 |
| - else: |
961 |
| - # Evaluate initvals that are None, symbolic or need to be transformed. |
962 |
| - # They can depend on other initvals from higher up in the graph, |
963 |
| - # which are therefore fed to the evaluation as "givens". |
964 |
| - test_value = getattr(rv_var.tag, "test_value", None) |
965 |
| - numeric_initvals[rv_var] = self._eval_initval( |
966 |
| - rv_var, initval, test_value, transform, given=numeric_initvals |
967 |
| - ) |
968 | 968 |
|
969 |
| - # Cache the evaluation results for next time. |
970 |
| - self._initial_point_cache = Point( |
971 |
| - [(self.rvs_to_values[k], v) for k, v in numeric_initvals.items()], model=self |
972 |
| - ) |
973 |
| - return self._initial_point_cache |
| 969 | + def fn(): |
| 970 | + numeric_initvals = {} |
| 971 | + # The entries in `initial_values` are already in topological order and can be evaluated one by one. |
| 972 | + for rv_var, initval in self.initial_values.items(): |
| 973 | + rv_value = self.rvs_to_values[rv_var] |
| 974 | + transform = getattr(rv_value.tag, "transform", None) |
| 975 | + if isinstance(initval, np.ndarray) and transform is None: |
| 976 | + # Only untransformed, numeric initvals can be taken as they are. |
| 977 | + numeric_initvals[rv_var] = initval |
| 978 | + else: |
| 979 | + # Evaluate initvals that are None, symbolic or need to be transformed. |
| 980 | + # They can depend on other initvals from higher up in the graph, |
| 981 | + # which are therefore fed to the evaluation as "givens". |
| 982 | + test_value = getattr(rv_var.tag, "test_value", None) |
| 983 | + numeric_initvals[rv_var] = self._eval_initval( |
| 984 | + rv_var, initval, test_value, transform, given=numeric_initvals |
| 985 | + ) |
| 986 | + if return_transformed: |
| 987 | + return {self.rvs_to_values[k]: v for k, v in numeric_initvals.items()} |
| 988 | + return numeric_initvals |
| 989 | + |
| 990 | + return fn |
974 | 991 |
|
975 | 992 | @property
|
976 | 993 | def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]:
|
|
0 commit comments