7
7
import numpy as np
8
8
from numpy import isfinite , nan_to_num , logical_not
9
9
import pymc3 as pm
10
- import time
11
10
from ..vartypes import discrete_types , typefilter
12
11
from ..model import modelcontext , Point
13
12
from ..theanof import inputvars
20
19
21
20
22
21
def find_MAP (start = None , vars = None , fmin = None ,
23
- return_raw = False , model = None , live_disp = False , callback = None ,
22
+ return_raw = False , model = None , callback = None ,
24
23
* args , ** kwargs ):
25
24
"""
26
25
Sets state to the local maximum a posteriori point given a model.
@@ -31,20 +30,16 @@ def find_MAP(start=None, vars=None, fmin=None,
31
30
----------
32
31
start : `dict` of parameter values (Defaults to `model.test_point`)
33
32
vars : list
34
- List of variables to set to MAP point (Defaults to all continuous).
33
+ List of variables to optimize and set to optimum (Defaults to all continuous).
35
34
fmin : function
36
35
Optimization algorithm (Defaults to `scipy.optimize.fmin_bfgs` unless
37
36
discrete variables are specified in `vars`, then
38
37
`scipy.optimize.fmin_powell` which will perform better).
39
38
return_raw : Bool
40
39
Whether to return extra value returned by fmin (Defaults to `False`)
41
40
model : Model (optional if in `with` context)
42
- live_disp : Bool
43
- Display table tracking optimization progress when run from within
44
- an IPython notebook.
45
41
callback : callable
46
- Callback function to pass to scipy optimization routine. Overrides
47
- live_disp if callback is given.
42
+ Callback function to pass to scipy optimization routine.
48
43
*args, **kwargs
49
44
Extra args passed to fmin
50
45
"""
@@ -89,7 +84,8 @@ def find_MAP(start=None, vars=None, fmin=None,
89
84
90
85
start = Point (start , model = model )
91
86
bij = DictToArrayBijection (ArrayOrdering (vars ), start )
92
-
87
+ logp_func = bij .mapf (model .fastlogp )
88
+ x0 = bij .map (start )
93
89
logp = bij .mapf (model .fastlogp_nojac )
94
90
def logp_o (point ):
95
91
return nan_to_high (- logp (point ))
@@ -100,15 +96,9 @@ def logp_o(point):
100
96
def grad_logp_o (point ):
101
97
return nan_to_num (- dlogp (point ))
102
98
103
- if live_disp and callback is None :
104
- callback = Monitor (bij , logp_o , model , grad_logp_o )
105
-
106
99
r = fmin (logp_o , bij .map (start ), fprime = grad_logp_o , callback = callback , * args , ** kwargs )
107
100
compute_gradient = True
108
101
else :
109
- if live_disp and callback is None :
110
- callback = Monitor (bij , logp_o , dlogp = None )
111
-
112
102
# Check to see if minimization function uses a starting value
113
103
if 'x0' in getargspec (fmin ).args :
114
104
r = fmin (logp_o , bij .map (start ), callback = callback , * args , ** kwargs )
@@ -121,12 +111,6 @@ def grad_logp_o(point):
121
111
else :
122
112
mx0 = r
123
113
124
- if live_disp :
125
- try :
126
- callback .update (mx0 )
127
- except :
128
- pass
129
-
130
114
mx = bij .rmap (mx0 )
131
115
132
116
allfinite_mx0 = allfinite (mx0 )
@@ -171,13 +155,16 @@ def message(name, values):
171
155
"density. 2) your distribution logp's are " +
172
156
"properly specified. Specific issues: \n " +
173
157
specific_errors )
174
- mx = {v .name : mx [v .name ].astype (v .dtype ) for v in model .vars }
158
+
159
+ vars = model .unobserved_RVs
160
+ mx = {var .name : value for var , value in zip (vars , model .fastfn (vars )(mx ))}
175
161
176
162
if return_raw :
177
163
return mx , r
178
164
else :
179
165
return mx
180
166
167
+
181
168
def allfinite (x ):
182
169
return np .all (isfinite (x ))
183
170
@@ -192,120 +179,6 @@ def allinmodel(vars, model):
192
179
raise ValueError ("Some variables not in the model: " + str (notin ))
193
180
194
181
195
-
196
- class Monitor (object ):
197
- def __init__ (self , bij , logp , model , dlogp = None ):
198
- try :
199
- from IPython .display import display
200
- from ipywidgets import HTML , VBox , HBox , FlexBox
201
- self .prog_table = HTML (width = '100%' )
202
- self .param_table = HTML (width = '100%' )
203
- r_col = VBox (children = [self .param_table ], padding = 3 , width = '100%' )
204
- l_col = HBox (children = [self .prog_table ], padding = 3 , width = '25%' )
205
- self .hor_align = FlexBox (children = [l_col , r_col ], width = '100%' , orientation = 'vertical' )
206
- display (self .hor_align )
207
- self .using_notebook = True
208
- self .update_interval = 1
209
- except :
210
- self .using_notebook = False
211
- self .update_interval = 2
212
-
213
- self .iters = 0
214
- self .bij = bij
215
- self .model = model
216
- self .fn = model .fastfn (model .unobserved_RVs )
217
- self .logp = logp
218
- self .dlogp = dlogp
219
- self .t_initial = time .time ()
220
- self .t0 = self .t_initial
221
- self .paramtable = {}
222
-
223
- def __call__ (self , x ):
224
- self .iters += 1
225
- if time .time () - self .t0 > self .update_interval or self .iters == 1 :
226
- self .update (x )
227
-
228
- def update (self , x ):
229
- self ._update_progtable (x )
230
- self ._update_paramtable (x )
231
- if self .using_notebook :
232
- self ._display_notebook ()
233
- self .t0 = time .time ()
234
-
235
- def _update_progtable (self , x ):
236
- s = time .time () - self .t_initial
237
- hours , remainder = divmod (int (s ), 3600 )
238
- minutes , seconds = divmod (remainder , 60 )
239
- self .t_elapsed = "{:2d}h{:2d}m{:2d}s" .format (hours , minutes , seconds )
240
- self .logpost = - 1.0 * np .float (self .logp (x ))
241
- self .dlogpost = np .linalg .norm (self .dlogp (x ))
242
-
243
- def _update_paramtable (self , x ):
244
- var_state = self .fn (self .bij .rmap (x ))
245
- for var , val in zip (self .model .unobserved_RVs , var_state ):
246
- if not var .name .endswith ("_" ):
247
- valstr = format_values (val )
248
- self .paramtable [var .name ] = {"size" : val .size , "valstr" : valstr }
249
-
250
- def _display_notebook (self ):
251
- ## Progress table
252
- html = r"""<style type="text/css">
253
- table { border-collapse:collapse }
254
- .tg {border-collapse:collapse;border-spacing:0;border:none;}
255
- .tg td{font-family:Arial, sans-serif;font-size:14px;padding:3px 3px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal;}
256
- .tg th{Impact, Charcoal, sans-serif;font-size:13px;font-weight:bold;padding:3px 3px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal; background-color:#0E688A;color:#ffffff;}
257
- .tg .tg-vkoh{white-space:pre;font-weight:normal;font-family:"Lucida Console", Monaco, monospace !important; background-color:#ffffff;color:#000000}
258
- .tg .tg-suao{font-weight:bold;font-family:"Lucida Console", Monaco, monospace !important;background-color:#0E688A;color:#ffffff;}
259
- """
260
- html += r"""
261
- </style>
262
- <table class="tg" style="undefined;">
263
- <col width="400px" />
264
- <tr>
265
- <th class= "tg-vkoh">Time Elapsed: {:s}</th>
266
- </tr>
267
- <tr>
268
- <th class= "tg-vkoh">Iteration: {:d}</th>
269
- </tr>
270
- <tr>
271
- <th class= "tg-vkoh">Log Posterior: {:.3f}</th>
272
- </tr>
273
- """ .format (self .t_elapsed , self .iters , self .logpost )
274
- if self .dlogp is not None :
275
- html += r"""
276
- <tr>
277
- <th class= "tg-vkoh">||grad||: {:.3f}</th>
278
- </tr>""" .format (self .dlogpost )
279
- html += "</table>"
280
- self .prog_table .value = html
281
- ## Parameter table
282
- html = r"""<style type="text/css">
283
- .tg .tg-bgft{font-weight:normal;font-family:"Lucida Console", Monaco, monospace !important;background-color:#0E688A;color:#ffffff;}
284
- .tg td{font-family:Arial, sans-serif;font-size:12px;padding:3px 3px;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#504A4E;color:#333;background-color:#fff;word-wrap: break-word;}
285
- .tg th{Impact, Charcoal, sans-serif;font-size:13px;font-weight:bold;padding:3px 3px;border-style:solid;border-width:1px;overflow:hidden;word-break:normal;border-color:#504A4E;background-color:#0E688A;color:#ffffff;}
286
- </style>
287
- <table class="tg" style="undefined;">
288
- <col width="130px" />
289
- <col width="50px" />
290
- <col width="600px" />
291
- <tr>
292
- <th class="tg">Parameter</th>
293
- <th class="tg">Size</th>
294
- <th class="tg">Current Value</th>
295
- </tr>
296
- """
297
- for var , values in self .paramtable .items ():
298
- html += r"""
299
- <tr>
300
- <td class="tg-bgft">{:s}</td>
301
- <td class="tg-vkoh">{:d}</td>
302
- <td class="tg-vkoh">{:s}</td>
303
- </tr>
304
- """ .format (var , values ["size" ], values ["valstr" ])
305
- html += "</table>"
306
- self .param_table .value = html
307
-
308
-
309
182
def format_values (val ):
310
183
fmt = "{:8.3f}"
311
184
if val .size == 1 :
0 commit comments