Skip to content

Commit ce7e64a

Browse files
authored
Improve target criterion for step size adaptation (#3656)
* [WIP] Improve target criterion for step size adaptation; closes #3655. See stan-dev/stan#2789 for a detailed analysis.
* debug test fail attempt 1
* bug fix
* fix tests
* formatting edit
* fix test
1 parent e4c775b commit ce7e64a

File tree

5 files changed

+131
-115
lines changed

5 files changed

+131
-115
lines changed

pymc3/step_methods/hmc/nuts.py

Lines changed: 27 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -209,12 +209,12 @@ def warnings(self):
209209

210210

211211
# A proposal for the next position
212-
Proposal = namedtuple("Proposal", "q, q_grad, energy, p_accept, logp")
212+
Proposal = namedtuple("Proposal", "q, q_grad, energy, log_p_accept, logp")
213213

214214
# A subtree of the binary tree built by nuts.
215215
Subtree = namedtuple(
216216
"Subtree",
217-
"left, right, p_sum, proposal, log_size, accept_sum, n_proposals")
217+
"left, right, p_sum, proposal, log_size, log_accept_sum, n_proposals")
218218

219219

220220
class _Tree:
@@ -245,7 +245,8 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
245245
start.q, start.q_grad, start.energy, 1.0, start.model_logp)
246246
self.depth = 0
247247
self.log_size = 0
248-
self.accept_sum = 0
248+
self.log_accept_sum = -np.inf
249+
self.mean_tree_accept = 0.
249250
self.n_proposals = 0
250251
self.p_sum = start.p.copy()
251252
self.max_energy_change = 0
@@ -280,7 +281,6 @@ def extend(self, direction):
280281
self.left = tree.right
281282

282283
self.depth += 1
283-
self.accept_sum += tree.accept_sum
284284
self.n_proposals += tree.n_proposals
285285

286286
if diverging or turning:
@@ -291,6 +291,8 @@ def extend(self, direction):
291291
self.proposal = tree.proposal
292292

293293
self.log_size = np.logaddexp(self.log_size, tree.log_size)
294+
self.log_accept_sum = np.logaddexp(self.log_accept_sum,
295+
tree.log_accept_sum)
294296
self.p_sum[:] += tree.p_sum
295297

296298
# Additional turning check only when tree depth > 0 to avoid redundant work
@@ -314,25 +316,31 @@ def _single_step(self, left, epsilon):
314316
error_msg = str(err)
315317
error = err
316318
else:
319+
# h - H0
317320
energy_change = right.energy - self.start_energy
318321
if np.isnan(energy_change):
319322
energy_change = np.inf
320323

321324
if np.abs(energy_change) > np.abs(self.max_energy_change):
322325
self.max_energy_change = energy_change
323326
if np.abs(energy_change) < self.Emax:
324-
p_accept = min(1, np.exp(-energy_change))
327+
# Acceptance statistic
328+
# e^{H(q_0, p_0) - H(q_n, p_n)} min(1, e^{H(q_0, p_0) - H(q_n, p_n)})
329+
# Saturated Metropolis accept probability with Boltzmann weight
330+
# if h - H0 < 0
331+
log_p_accept = -energy_change + min(0., -energy_change)
325332
log_size = -energy_change
326333
proposal = Proposal(
327-
right.q, right.q_grad, right.energy, p_accept, right.model_logp)
334+
right.q, right.q_grad, right.energy, log_p_accept,
335+
right.model_logp)
328336
tree = Subtree(right, right, right.p,
329-
proposal, log_size, p_accept, 1)
337+
proposal, log_size, log_p_accept, 1)
330338
return tree, None, False
331339
else:
332340
error_msg = ("Energy change in leapfrog step is too large: %s."
333341
% energy_change)
334342
error = None
335-
tree = Subtree(None, None, None, None, -np.inf, 0, 1)
343+
tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
336344
divergance_info = DivergenceInfo(error_msg, error, left)
337345
return tree, divergance_info, False
338346

@@ -362,26 +370,34 @@ def _build_subtree(self, left, depth, epsilon):
362370
turning = (turning | turning1 | turning2)
363371

364372
log_size = np.logaddexp(tree1.log_size, tree2.log_size)
373+
log_accept_sum = np.logaddexp(tree1.log_accept_sum,
374+
tree2.log_accept_sum)
365375
if logbern(tree2.log_size - log_size):
366376
proposal = tree2.proposal
367377
else:
368378
proposal = tree1.proposal
369379
else:
370380
p_sum = tree1.p_sum
371381
log_size = tree1.log_size
382+
log_accept_sum = tree1.log_accept_sum
372383
proposal = tree1.proposal
373384

374-
accept_sum = tree1.accept_sum + tree2.accept_sum
375385
n_proposals = tree1.n_proposals + tree2.n_proposals
376386

377387
tree = Subtree(left, right, p_sum, proposal,
378-
log_size, accept_sum, n_proposals)
388+
log_size, log_accept_sum, n_proposals)
379389
return tree, diverging, turning
380390

381391
def stats(self):
392+
# Update accept stat if any subtrees were accepted
393+
if self.log_size > 0:
394+
# Remove contribution from initial state which is always a perfect
395+
# accept
396+
sum_weight = np.expm1(self.log_size)
397+
self.mean_tree_accept = np.exp(self.log_accept_sum) / sum_weight
382398
return {
383399
'depth': self.depth,
384-
'mean_tree_accept': self.accept_sum / self.n_proposals,
400+
'mean_tree_accept': self.mean_tree_accept,
385401
'energy_error': self.proposal.energy - self.start.energy,
386402
'energy': self.proposal.energy,
387403
'tree_size': self.n_proposals,

pymc3/tests/test_data_container.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ def test_sample(self):
3838
pm.Normal('obs', b * x_shared, np.sqrt(1e-2), observed=y)
3939
prior_trace0 = pm.sample_prior_predictive(1000)
4040

41-
trace = pm.sample(1000, init=None, progressbar=False)
41+
trace = pm.sample(1000, init=None, tune=1000, chains=1)
4242
pp_trace0 = pm.sample_posterior_predictive(trace, 1000)
4343

4444
x_shared.set_value(x_pred)
@@ -79,11 +79,11 @@ def test_sample_after_set_data(self):
7979
pm.Normal('obs', beta * x, np.sqrt(1e-2), observed=y)
8080
pm.sample(1000, init=None, tune=1000, chains=1)
8181
# Predict on new data.
82-
new_x = [5, 6, 9]
83-
new_y = [5, 6, 9]
82+
new_x = [5., 6., 9.]
83+
new_y = [5., 6., 9.]
8484
with model:
8585
pm.set_data(new_data={'x': new_x, 'y': new_y})
86-
new_trace = pm.sample()
86+
new_trace = pm.sample(1000, init=None, tune=1000, chains=1)
8787
pp_trace = pm.sample_posterior_predictive(new_trace, 1000)
8888

8989
assert pp_trace['obs'].shape == (1000, 3)

pymc3/tests/test_glm.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def test_glm_link_func2(self):
9898
with Model() as model:
9999
GLM.from_formula('y ~ x', self.data_logistic2,
100100
family=families.Binomial(priors={'n': self.data_logistic2['n']}))
101-
trace = sample(1000, progressbar=False,
101+
trace = sample(1000, progressbar=False, init='adapt_diag',
102102
random_seed=self.random_seed)
103103

104104
assert round(abs(np.mean(trace['Intercept'])-self.intercept), 1) == 0

pymc3/tests/test_shared.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ def test_sample(self):
3939
pm.Normal('obs', b * x_shared, np.sqrt(1e-2), observed=y)
4040
prior_trace0 = pm.sample_prior_predictive(1000)
4141

42-
trace = pm.sample(1000, init=None, progressbar=False)
42+
trace = pm.sample(1000, init=None, tune=1000, chains=1)
4343
pp_trace0 = pm.sample_posterior_predictive(trace, 1000)
4444

4545
x_shared.set_value(x_pred)

pymc3/tests/test_step.py

Lines changed: 98 additions & 98 deletions
Original file line number | Diff line number | Diff line change
@@ -372,104 +372,104 @@ class TestStepMethods: # yield test doesn't work subclassing object
372372
[
373373
0.550575,
374374
0.550575,
375-
0.80046332,
376-
0.91590059,
377-
1.34621916,
378-
1.34621916,
379-
-0.63917773,
380-
-0.65770809,
381-
-0.65770809,
382-
-0.64512868,
383-
-1.05448153,
384-
-0.5225666,
385-
0.14335153,
386-
-0.0034499,
387-
-0.0034499,
388-
0.05309212,
389-
-0.53186371,
390-
0.29325825,
391-
0.43210854,
392-
0.56284837,
393-
0.56284837,
394-
0.38041767,
395-
0.47322034,
396-
0.49937368,
397-
0.49937368,
398-
0.44424258,
399-
0.44424258,
400-
-0.02790848,
401-
-0.40470145,
402-
-0.35725567,
403-
-0.43744228,
404-
0.41955432,
405-
0.31099421,
406-
0.31099421,
407-
0.65811717,
408-
0.66649398,
409-
0.38493786,
410-
0.54114658,
411-
0.54114658,
412-
0.68222408,
413-
0.66404942,
414-
1.44143108,
415-
1.15638799,
416-
-0.06775775,
417-
-0.06775775,
418-
0.30418561,
419-
0.23543403,
420-
0.57934404,
421-
-0.5435111,
422-
-0.47938915,
423-
-0.23816662,
424-
0.36793792,
425-
0.36793792,
426-
0.64980016,
427-
0.52150456,
428-
0.64643321,
429-
0.26130179,
430-
1.10569077,
431-
1.10569077,
432-
1.23662797,
433-
-0.36928735,
434-
-0.14303069,
435-
0.85298904,
436-
0.85298904,
437-
0.31422085,
438-
0.32113762,
439-
0.32113762,
440-
1.0692238,
441-
1.0692238,
442-
1.60127576,
443-
1.49249738,
444-
1.09065107,
445-
0.84264371,
446-
0.84264371,
447-
-0.08832343,
448-
0.04868027,
449-
-0.02679449,
450-
-0.02679449,
451-
0.91989101,
452-
0.65754478,
453-
-0.39220625,
454-
0.08379492,
455-
1.03055634,
456-
1.03055634,
457-
1.71071332,
458-
1.58740483,
459-
1.67905741,
460-
0.77744868,
461-
0.15050587,
462-
0.15050587,
463-
0.73979127,
464-
0.15445515,
465-
0.13134717,
466-
0.85068974,
467-
0.85068974,
468-
0.6974799,
469-
0.16170472,
470-
0.86405959,
471-
0.86405959,
472-
-0.22032854,
375+
0.80031201,
376+
0.91580544,
377+
1.34622953,
378+
1.34622953,
379+
-0.63861533,
380+
-0.62101385,
381+
-0.62101385,
382+
-0.60250375,
383+
-1.04753424,
384+
-0.34850626,
385+
0.35882649,
386+
-0.20339408,
387+
-0.18077466,
388+
-0.18077466,
389+
0.1242007,
390+
-0.48708213,
391+
0.01216292,
392+
0.01216292,
393+
-0.15991487,
394+
0.0118306,
395+
0.0118306,
396+
0.02512962,
397+
-0.06002705,
398+
0.61278464,
399+
-0.45991609,
400+
-0.45991609,
401+
-0.45991609,
402+
-0.3067988,
403+
-0.3067988,
404+
-0.30830273,
405+
-0.62877494,
406+
-0.5896293,
407+
0.32740518,
408+
0.32740518,
409+
0.55321326,
410+
0.34885231,
411+
0.34885231,
412+
0.35304997,
413+
1.20016133,
414+
1.20016133,
415+
1.26432486,
416+
1.22481613,
417+
1.46040499,
418+
1.2251786,
419+
0.29954482,
420+
0.29954482,
421+
0.5713582,
422+
0.5755183,
423+
0.26968846,
424+
0.68253483,
425+
0.68253483,
426+
0.69418724,
427+
1.4172782,
428+
1.4172782,
429+
0.85063608,
430+
0.23409974,
431+
-0.65012501,
432+
1.16211157,
433+
-0.04844954,
434+
1.34390994,
435+
-0.44058335,
436+
-0.44058335,
437+
0.85096033,
438+
0.98734074,
439+
1.31200906,
440+
1.2751574,
441+
1.2751574,
442+
0.04377635,
443+
0.08244824,
444+
0.6342471,
445+
-0.31243596,
446+
1.0165907,
447+
-0.19025897,
448+
-0.19025897,
449+
0.02133041,
450+
-0.02335463,
451+
0.43923434,
452+
-0.45033488,
453+
0.05985518,
454+
-0.10019701,
455+
1.34229104,
456+
1.28571862,
457+
0.59557205,
458+
0.63730268,
459+
0.63730268,
460+
0.54269992,
461+
0.54269992,
462+
-0.48334519,
463+
1.02199273,
464+
-0.17367903,
465+
-0.17367903,
466+
0.8470911,
467+
-0.12868214,
468+
1.8986946,
469+
1.55412619,
470+
1.55412619,
471+
0.90228003,
472+
1.3328478
473473
]
474474
),
475475
}

0 commit comments

Comments
 (0)