Skip to content

Commit 1f34a5f

Browse files
kyleabeauchamp authored and Junpeng Lao committed
Add MAP estimates for transformed and untransformed variables (#2523)
* Add MAP estimates for transformed and untransformed variables
* Fix docstring
* Fix lint
1 parent d01aaf0 commit 1f34a5f

File tree

3 files changed

+44
-137
lines changed

3 files changed

+44
-137
lines changed

pymc3/tests/models.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -157,3 +157,19 @@ def beta_bernoulli(n=2):
157157
pm.Beta('x', 3, 1, shape=n, transform=None)
158158
pm.Bernoulli('y', 0.5)
159159
return model.test_point, model, None
160+
161+
162+
def simple_normal(bounded_prior=False):
    """Build a single-observation normal model for MLE / MAP tests (issue #2482).

    Parameters
    ----------
    bounded_prior : bool
        If True, the mean ``mu_i`` gets a ``Uniform(9, 12)`` prior; otherwise
        it gets a ``Flat`` prior.

    Returns
    -------
    tuple
        ``(model.test_point, model, None)``.
    """
    observed_value = 10.0
    noise_sd = 1.0
    # Bounds for the uniform RV; they need to be non-symmetric to reproduce
    # the issue.
    lower, upper = (9, 12)

    with pm.Model() as model:
        mu_i = pm.Uniform("mu_i", lower, upper) if bounded_prior else pm.Flat("mu_i")
        pm.Normal("X_obs", mu=mu_i, sd=noise_sd, observed=observed_value)

    return model.test_point, model, None

pymc3/tests/test_tuning.py

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from numpy import inf
3-
from pymc3.tuning import scaling
3+
from pymc3.tuning import scaling, find_MAP
44
from . import models
55

66

@@ -14,3 +14,21 @@ def test_guess_scaling():
1414
start, model, _ = models.non_normal(n=5)
1515
a1 = scaling.guess_scaling(start, model=model)
1616
assert all((a1 > 0) & (a1 < 1e200))
17+
18+
19+
def test_mle_jacobian():
    """Test MAP / MLE estimation for distributions with flat priors.

    Runs ``find_MAP`` on ``models.simple_normal`` first with the unbounded
    prior and then with the bounded prior, and checks that the estimate of
    ``mu_i`` matches the expected value in both cases.
    """
    expected_mu = 10.0  # Simple normal model should give mu=10.0
    tolerance = 1E-5  # this rtol should work on both floatX precisions

    # Same order as before: unbounded prior first, then the bounded prior.
    for bounded in (False, True):
        _, model, _ = models.simple_normal(bounded_prior=bounded)
        with model:
            map_estimate = find_MAP(model=model)

        np.testing.assert_allclose(map_estimate["mu_i"], expected_mu, rtol=tolerance)

pymc3/tuning/starting.py

Lines changed: 9 additions & 136 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
from numpy import isfinite, nan_to_num, logical_not
99
import pymc3 as pm
10-
import time
1110
from ..vartypes import discrete_types, typefilter
1211
from ..model import modelcontext, Point
1312
from ..theanof import inputvars
@@ -20,7 +19,7 @@
2019

2120

2221
def find_MAP(start=None, vars=None, fmin=None,
23-
return_raw=False, model=None, live_disp=False, callback=None,
22+
return_raw=False, model=None, callback=None,
2423
*args, **kwargs):
2524
"""
2625
Sets state to the local maximum a posteriori point given a model.
@@ -31,20 +30,16 @@ def find_MAP(start=None, vars=None, fmin=None,
3130
----------
3231
start : `dict` of parameter values (Defaults to `model.test_point`)
3332
vars : list
34-
List of variables to set to MAP point (Defaults to all continuous).
33+
List of variables to optimize and set to optimum (Defaults to all continuous).
3534
fmin : function
3635
Optimization algorithm (Defaults to `scipy.optimize.fmin_bfgs` unless
3736
discrete variables are specified in `vars`, then
3837
`scipy.optimize.fmin_powell` which will perform better).
3938
return_raw : Bool
4039
Whether to return extra value returned by fmin (Defaults to `False`)
4140
model : Model (optional if in `with` context)
42-
live_disp : Bool
43-
Display table tracking optimization progress when run from within
44-
an IPython notebook.
4541
callback : callable
46-
Callback function to pass to scipy optimization routine. Overrides
47-
live_disp if callback is given.
42+
Callback function to pass to scipy optimization routine.
4843
*args, **kwargs
4944
Extra args passed to fmin
5045
"""
@@ -89,7 +84,8 @@ def find_MAP(start=None, vars=None, fmin=None,
8984

9085
start = Point(start, model=model)
9186
bij = DictToArrayBijection(ArrayOrdering(vars), start)
92-
87+
logp_func = bij.mapf(model.fastlogp)
88+
x0 = bij.map(start)
9389
logp = bij.mapf(model.fastlogp_nojac)
9490
def logp_o(point):
9591
return nan_to_high(-logp(point))
@@ -100,15 +96,9 @@ def logp_o(point):
10096
def grad_logp_o(point):
10197
return nan_to_num(-dlogp(point))
10298

103-
if live_disp and callback is None:
104-
callback = Monitor(bij, logp_o, model, grad_logp_o)
105-
10699
r = fmin(logp_o, bij.map(start), fprime=grad_logp_o, callback=callback, *args, **kwargs)
107100
compute_gradient = True
108101
else:
109-
if live_disp and callback is None:
110-
callback = Monitor(bij, logp_o, dlogp=None)
111-
112102
# Check to see if minimization function uses a starting value
113103
if 'x0' in getargspec(fmin).args:
114104
r = fmin(logp_o, bij.map(start), callback=callback, *args, **kwargs)
@@ -121,12 +111,6 @@ def grad_logp_o(point):
121111
else:
122112
mx0 = r
123113

124-
if live_disp:
125-
try:
126-
callback.update(mx0)
127-
except:
128-
pass
129-
130114
mx = bij.rmap(mx0)
131115

132116
allfinite_mx0 = allfinite(mx0)
@@ -171,13 +155,16 @@ def message(name, values):
171155
"density. 2) your distribution logp's are " +
172156
"properly specified. Specific issues: \n" +
173157
specific_errors)
174-
mx = {v.name: mx[v.name].astype(v.dtype) for v in model.vars}
158+
159+
vars = model.unobserved_RVs
160+
mx = {var.name: value for var, value in zip(vars, model.fastfn(vars)(mx))}
175161

176162
if return_raw:
177163
return mx, r
178164
else:
179165
return mx
180166

167+
181168
def allfinite(x):
    """Return whether every element of ``x`` is finite (no NaN or +/-inf)."""
    finite_mask = isfinite(x)
    return np.all(finite_mask)
183170

@@ -192,120 +179,6 @@ def allinmodel(vars, model):
192179
raise ValueError("Some variables not in the model: " + str(notin))
193180

194181

195-
196-
class Monitor(object):
197-
def __init__(self, bij, logp, model, dlogp=None):
198-
try:
199-
from IPython.display import display
200-
from ipywidgets import HTML, VBox, HBox, FlexBox
201-
self.prog_table = HTML(width='100%')
202-
self.param_table = HTML(width='100%')
203-
r_col = VBox(children=[self.param_table], padding=3, width='100%')
204-
l_col = HBox(children=[self.prog_table], padding=3, width='25%')
205-
self.hor_align = FlexBox(children = [l_col, r_col], width='100%', orientation='vertical')
206-
display(self.hor_align)
207-
self.using_notebook = True
208-
self.update_interval = 1
209-
except:
210-
self.using_notebook = False
211-
self.update_interval = 2
212-
213-
self.iters = 0
214-
self.bij = bij
215-
self.model = model
216-
self.fn = model.fastfn(model.unobserved_RVs)
217-
self.logp = logp
218-
self.dlogp = dlogp
219-
self.t_initial = time.time()
220-
self.t0 = self.t_initial
221-
self.paramtable = {}
222-
223-
def __call__(self, x):
224-
self.iters += 1
225-
if time.time() - self.t0 > self.update_interval or self.iters == 1:
226-
self.update(x)
227-
228-
def update(self, x):
229-
self._update_progtable(x)
230-
self._update_paramtable(x)
231-
if self.using_notebook:
232-
self._display_notebook()
233-
self.t0 = time.time()
234-
235-
def _update_progtable(self, x):
236-
s = time.time() - self.t_initial
237-
hours, remainder = divmod(int(s), 3600)
238-
minutes, seconds = divmod(remainder, 60)
239-
self.t_elapsed = "{:2d}h{:2d}m{:2d}s".format(hours, minutes, seconds)
240-
self.logpost = -1.0*np.float(self.logp(x))
241-
self.dlogpost = np.linalg.norm(self.dlogp(x))
242-
243-
def _update_paramtable(self, x):
244-
var_state = self.fn(self.bij.rmap(x))
245-
for var, val in zip(self.model.unobserved_RVs, var_state):
246-
if not var.name.endswith("_"):
247-
valstr = format_values(val)
248-
self.paramtable[var.name] = {"size": val.size, "valstr": valstr}
249-
250-
def _display_notebook(self):
251-
## Progress table
252-
html = r"""<style type="text/css">
253-
table { border-collapse:collapse }
254-
.tg {border-collapse:collapse;border-spacing:0;border:none;}
255-
.tg td{font-family:Arial, sans-serif;font-size:14px;padding:3px 3px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal;}
256-
.tg th{Impact, Charcoal, sans-serif;font-size:13px;font-weight:bold;padding:3px 3px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal; background-color:#0E688A;color:#ffffff;}
257-
.tg .tg-vkoh{white-space:pre;font-weight:normal;font-family:"Lucida Console", Monaco, monospace !important; background-color:#ffffff;color:#000000}
258-
.tg .tg-suao{font-weight:bold;font-family:"Lucida Console", Monaco, monospace !important;background-color:#0E688A;color:#ffffff;}
259-
"""
260-
html += r"""
261-
</style>
262-
<table class="tg" style="undefined;">
263-
<col width="400px" />
264-
<tr>
265-
<th class= "tg-vkoh">Time Elapsed: {:s}</th>
266-
</tr>
267-
<tr>
268-
<th class= "tg-vkoh">Iteration: {:d}</th>
269-
</tr>
270-
<tr>
271-
<th class= "tg-vkoh">Log Posterior: {:.3f}</th>
272-
</tr>
273-
""".format(self.t_elapsed, self.iters, self.logpost)
274-
if self.dlogp is not None:
275-
html += r"""
276-
<tr>
277-
<th class= "tg-vkoh">||grad||: {:.3f}</th>
278-
</tr>""".format(self.dlogpost)
279-
html += "</table>"
280-
self.prog_table.value = html
281-
## Parameter table
282-
html = r"""<style type="text/css">
283-
.tg .tg-bgft{font-weight:normal;font-family:"Lucida Console", Monaco, monospace !important;background-color:#0E688A;color:#ffffff;}
284-
.tg td{font-family:Arial, sans-serif;font-size:12px;padding:3px 3px;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#504A4E;color:#333;background-color:#fff;word-wrap: break-word;}
285-
.tg th{Impact, Charcoal, sans-serif;font-size:13px;font-weight:bold;padding:3px 3px;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#504A4E;background-color:#0E688A;color:#ffffff;}
286-
</style>
287-
<table class="tg" style="undefined;">
288-
<col width="130px" />
289-
<col width="50px" />
290-
<col width="600px" />
291-
<tr>
292-
<th class="tg">Parameter</th>
293-
<th class="tg">Size</th>
294-
<th class="tg">Current Value</th>
295-
</tr>
296-
"""
297-
for var, values in self.paramtable.items():
298-
html += r"""
299-
<tr>
300-
<td class="tg-bgft">{:s}</td>
301-
<td class="tg-vkoh">{:d}</td>
302-
<td class="tg-vkoh">{:s}</td>
303-
</tr>
304-
""".format(var, values["size"], values["valstr"])
305-
html += "</table>"
306-
self.param_table.value = html
307-
308-
309182
def format_values(val):
310183
fmt = "{:8.3f}"
311184
if val.size == 1:

0 commit comments

Comments
 (0)