Skip to content

Commit e761327

Browse files
Manage initial values by RV var instead of RV value var
1 parent 3aa5c54 commit e761327

File tree

3 files changed

+24
-22
lines changed

3 files changed

+24
-22
lines changed

pymc3/model.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -935,7 +935,9 @@ def test_point(self) -> Dict[str, np.ndarray]:
935935
@property
936936
def initial_point(self) -> Dict[str, np.ndarray]:
937937
"""Maps free variable names to transformed, numeric initial values."""
938-
if set(self._initial_point_cache) != {get_var_name(k) for k in self.initial_values}:
938+
if set(self._initial_point_cache) != {
939+
get_var_name(self.rvs_to_values[k]) for k in self.initial_values
940+
}:
939941
return self.recompute_initial_point()
940942
return self._initial_point_cache
941943

@@ -949,40 +951,40 @@ def recompute_initial_point(self) -> Dict[str, np.ndarray]:
949951
"""
950952
numeric_initvals = {}
951953
# The entries in `initial_values` are already in topological order and can be evaluated one by one.
952-
for rv_value, initval in self.initial_values.items():
953-
rv_var = self.values_to_rvs[rv_value]
954+
for rv_var, initval in self.initial_values.items():
955+
rv_value = self.rvs_to_values[rv_var]
954956
transform = getattr(rv_value.tag, "transform", None)
955957
if isinstance(initval, np.ndarray) and transform is None:
956958
# Only untransformed, numeric initvals can be taken as they are.
957-
numeric_initvals[rv_value] = initval
959+
numeric_initvals[rv_var] = initval
958960
else:
959961
# Evaluate initvals that are None, symbolic or need to be transformed.
960962
# They can depend on other initvals from higher up in the graph,
961963
# which are therefore fed to the evaluation as "givens".
962964
test_value = getattr(rv_var.tag, "test_value", None)
963-
numeric_initvals[rv_value] = self._eval_initval(
965+
numeric_initvals[rv_var] = self._eval_initval(
964966
rv_var, initval, test_value, transform, given=numeric_initvals
965967
)
966968

967969
# Cache the evaluation results for next time.
968-
self._initial_point_cache = Point(list(numeric_initvals.items()), model=self)
970+
self._initial_point_cache = Point(
971+
[(self.rvs_to_values[k], v) for k, v in numeric_initvals.items()], model=self
972+
)
969973
return self._initial_point_cache
970974

971975
@property
972976
def initial_values(self) -> Dict[TensorVariable, Optional[Union[np.ndarray, Variable]]]:
973977
"""Maps transformed variables to initial value placeholders.
974978
975-
⚠ The keys are NOT the objects returned by, `pm.Normal(...)`.
976-
For a name-based dictionary use the `get_initial_point()` method.
979+
Keys are the random variables (as returned by e.g. ``pm.Uniform()``).
977980
"""
978981
return self._initial_values
979982

980983
def set_initval(self, rv_var, initval):
981984
if initval is not None:
982985
initval = rv_var.type.filter(initval)
983986

984-
rv_value_var = self.rvs_to_values[rv_var]
985-
self.initial_values[rv_value_var] = initval
987+
self.initial_values[rv_var] = initval
986988

987989
def _eval_initval(
988990
self,
@@ -1031,16 +1033,16 @@ def _eval_initval(
10311033
value = rv_var
10321034
rv_var = at.as_tensor_variable(transform.forward(rv_var, value))
10331035

1034-
def initval_to_rvval(value_var, value):
1035-
rv_var = self.values_to_rvs[value_var]
1036+
def initval_to_rvval(rv_var, value):
1037+
value_var = self.rvs_to_values[rv_var]
10361038
initval = value_var.type.make_constant(value)
10371039
transform = getattr(value_var.tag, "transform", None)
10381040
if transform:
10391041
return transform.backward(rv_var, initval)
10401042
else:
10411043
return initval
10421044

1043-
givens = {self.values_to_rvs[k]: initval_to_rvval(k, v) for k, v in given.items()}
1045+
givens = {k: initval_to_rvval(k, v) for k, v in given.items()}
10441046
initval_fn = aesara.function([], rv_var, mode=mode, givens=givens, on_unused_input="ignore")
10451047
try:
10461048
initval = initval_fn()

pymc3/tests/test_initvals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_dependent_initvals(self):
9393
assert ip["B_interval__"] == 0
9494

9595
# Modify initval of L and re-evaluate
96-
pmodel.initial_values[pmodel.rvs_to_values[L]] = 0.9
96+
pmodel.initial_values[L] = 0.9
9797
ip = pmodel.recompute_initial_point()
9898
assert ip["B_interval__"] < 0
9999
pass

pymc3/tests/test_model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -514,11 +514,11 @@ def test_initial_point():
514514
with model:
515515
y = pm.Normal("y", initval=y_initval)
516516

517-
assert model.rvs_to_values[a] in model.initial_values
518-
assert model.rvs_to_values[x] in model.initial_values
519-
assert model.initial_values[b_value_var] == b_initval
517+
assert a in model.initial_values
518+
assert x in model.initial_values
519+
assert model.initial_values[b] == b_initval
520520
assert model.recompute_initial_point()["b_interval__"] == b_initval_trans
521-
assert model.initial_values[model.rvs_to_values[y]] == y_initval
521+
assert model.initial_values[y] == y_initval
522522

523523

524524
def test_point_logps():
@@ -641,17 +641,17 @@ def test_set_initval():
641641
alpha = pm.HalfNormal("alpha", initval=100)
642642
value = pm.NegativeBinomial("value", mu=mu, alpha=alpha)
643643

644-
assert np.array_equal(model.initial_values[model.rvs_to_values[mu]], np.array([[100.0]]))
645-
np.testing.assert_array_equal(model.initial_values[model.rvs_to_values[alpha]], np.array(100))
646-
assert model.initial_values[model.rvs_to_values[value]] is None
644+
assert np.array_equal(model.initial_values[mu], np.array([[100.0]]))
645+
np.testing.assert_array_equal(model.initial_values[alpha], np.array(100))
646+
assert model.initial_values[value] is None
647647

648648
# `Flat` cannot be sampled, so let's make sure that doesn't break initial
649649
# value computations
650650
with pm.Model() as model:
651651
x = pm.Flat("x")
652652
y = pm.Normal("y", x, 1)
653653

654-
assert model.rvs_to_values[y] in model.initial_values
654+
assert y in model.initial_values
655655

656656

657657
def test_datalogpt_multiple_shapes():

0 commit comments

Comments
 (0)