Skip to content

Commit fd7fc60

Browse files
Fix make_initial_point_expression recursion bug
Co-authored-by: Michael Osthege <[email protected]>
1 parent 275c145 commit fd7fc60

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

pymc/initial_point.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def make_initial_point_expression(
291291
jitter.name = f"{variable.name}_jitter"
292292
value = value + jitter
293293

294+
value = value.astype(variable.dtype)
294295
initial_values_transformed.append(value)
295296

296297
if transform is not None:
@@ -310,18 +311,17 @@ def make_initial_point_expression(
310311
initial_values_clone = copy_graph.outputs[n_variables:-n_variables]
311312
initial_values_transformed_clone = copy_graph.outputs[-n_variables:]
312313

313-
# In the order the variables were created, replace each previous variable
314-
# with the init_point for that variable.
315-
initial_values = []
316-
initial_values_transformed = []
317-
318-
for i in range(n_variables):
319-
outputs = [initial_values_clone[i], initial_values_transformed_clone[i]]
320-
graph = FunctionGraph(outputs=outputs, clone=False)
321-
graph.replace_all(zip(free_rvs_clone[:i], initial_values), import_missing=True)
322-
initial_values.append(graph.outputs[0])
323-
initial_values_transformed.append(graph.outputs[1])
324-
325-
if return_transformed:
326-
return initial_values_transformed
327-
return initial_values
314+
# We now replace all rvs by the respective initial_point expressions
315+
# in the constrained (untransformed) space. We do this in reverse topological
316+
# order, so that later nodes do not reintroduce expressions with earlier
317+
# rvs that would need to once again be replaced by their initial_points
318+
graph = FunctionGraph(outputs=free_rvs_clone, clone=False)
319+
replacements = reversed(list(zip(free_rvs_clone, initial_values_clone)))
320+
graph.replace_all(replacements, import_missing=True)
321+
322+
if not return_transformed:
323+
return graph.outputs
324+
# Because the unconstrained (transformed) expressions are a subgraph of the
325+
# constrained initial point they were also automatically updated inplace
326+
# when calling graph.replace_all above, so we don't need to do anything else
327+
return initial_values_transformed_clone

pymc/tests/test_initial_point.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,46 @@ def test_make_initial_point_fns_per_chain_checks_kwargs(self):
7474
def test_dependent_initvals(self):
7575
with pm.Model() as pmodel:
7676
L = pm.Uniform("L", 0, 1, initval=0.5)
77-
B = pm.Uniform("B", lower=L, upper=2, initval=1.25)
77+
U = pm.Uniform("U", lower=9, upper=10, initval=9.5)
78+
B1 = pm.Uniform("B1", lower=L, upper=U, initval=5)
79+
B2 = pm.Uniform("B2", lower=L, upper=U, initval=(L + U) / 2)
7880
ip = pmodel.recompute_initial_point(seed=0)
7981
assert ip["L_interval__"] == 0
80-
assert ip["B_interval__"] == 0
82+
assert ip["U_interval__"] == 0
83+
assert ip["B1_interval__"] == 0
84+
assert ip["B2_interval__"] == 0
8185

8286
# Modify initval of L and re-evaluate
83-
pmodel.initial_values[L] = 0.9
87+
pmodel.initial_values[U] = 9.9
8488
ip = pmodel.recompute_initial_point(seed=0)
85-
assert ip["B_interval__"] < 0
89+
assert ip["B1_interval__"] < 0
90+
assert ip["B2_interval__"] == 0
8691
pass
8792

93+
def test_nested_initvals(self):
94+
# See issue #5168
95+
with pm.Model() as pmodel:
96+
one = pm.LogNormal("one", mu=np.log(1), sd=1e-5, initval="prior")
97+
two = pm.Lognormal("two", mu=np.log(one * 2), sd=1e-5, initval="prior")
98+
three = pm.LogNormal("three", mu=np.log(two * 2), sd=1e-5, initval="prior")
99+
four = pm.LogNormal("four", mu=np.log(three * 2), sd=1e-5, initval="prior")
100+
five = pm.LogNormal("five", mu=np.log(four * 2), sd=1e-5, initval="prior")
101+
six = pm.LogNormal("six", mu=np.log(five * 2), sd=1e-5, initval="prior")
102+
103+
ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=True)(0).values())
104+
assert np.allclose(np.exp(ip_vals), [1, 2, 4, 8, 16, 32], rtol=1e-3)
105+
106+
ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=False)(0).values())
107+
assert np.allclose(ip_vals, [1, 2, 4, 8, 16, 32], rtol=1e-3)
108+
109+
pmodel.initial_values[four] = 1
110+
111+
ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=True)(0).values())
112+
assert np.allclose(np.exp(ip_vals), [1, 2, 4, 1, 2, 4], rtol=1e-3)
113+
114+
ip_vals = list(make_initial_point_fn(model=pmodel, return_transformed=False)(0).values())
115+
assert np.allclose(ip_vals, [1, 2, 4, 1, 2, 4], rtol=1e-3)
116+
88117
def test_initval_resizing(self):
89118
with pm.Model() as pmodel:
90119
data = aesara.shared(np.arange(4))

0 commit comments

Comments
 (0)