Skip to content

Commit 4cef603

Browse files
colintwiecki
colin
authored andcommitted
Refactor and make more descriptive
1 parent 4b2f66c commit 4cef603

File tree

3 files changed

+55
-54
lines changed

3 files changed

+55
-54
lines changed

pymc3/step_methods/hmc/nuts.py

Lines changed: 53 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -49,49 +49,47 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
4949
self.k = k
5050
self.t0 = t0
5151

52-
self.Hbar = 0
52+
self.h_bar = 0
5353
self.u = np.log(self.step_size * 10)
5454
self.m = 1
5555

5656
def astep(self, q0):
57-
Emax = self.Emax
58-
e = self.step_size
59-
6057
p0 = self.potential.random()
61-
E0 = self.compute_energy(q0, p0)
58+
start_energy = self.compute_energy(q0, p0)
6259

6360
u = nr.uniform()
6461
q = qn = qp = q0
65-
p = pn = pp = p0
62+
pn = pp = p0
63+
64+
tree_size, depth = 1., 0
65+
keep_sampling = True
6666

67-
n, s, j = 1, 1, 0
67+
while keep_sampling:
68+
direction = bern(0.5) * 2 - 1
69+
q_edge, p_edge = {-1: (qn, pn), 1: (qp, pp)}[direction]
6870

69-
while s == 1:
70-
v = bern(0.5) * 2 - 1
71+
q_edge, p_edge, proposal, subtree_size, is_valid_sample, a, na = buildtree(
72+
self.leapfrog, q_edge, p_edge,
73+
u, direction, depth,
74+
self.step_size, self.Emax, start_energy)
7175

72-
if v == -1:
73-
qn, pn, _, _, q1, n1, s1, a, na = buildtree(
74-
self.leapfrog, qn, pn, u, v, j, e, Emax, E0)
76+
if direction == -1:
77+
qn, pn = q_edge, p_edge
7578
else:
76-
_, _, qp, pp, q1, n1, s1, a, na = buildtree(
77-
self.leapfrog, qp, pp, u, v, j, e, Emax, E0)
79+
qp, pp = q_edge, p_edge
7880

79-
if s1 == 1 and bern(min(1, n1 * 1. / n)):
80-
q = q1
81+
if is_valid_sample and bern(min(1, subtree_size / tree_size)):
82+
q = proposal
8183

82-
n = n + n1
84+
tree_size += subtree_size
8385

8486
span = qp - qn
85-
s = s1 * (span.dot(pn) >= 0) * (span.dot(pp) >= 0)
86-
j = j + 1
87-
88-
p = -p
87+
keep_sampling = is_valid_sample and (span.dot(pn) >= 0) and (span.dot(pp) >= 0)
88+
depth += 1
8989

9090
w = 1. / (self.m + self.t0)
91-
self.Hbar = (1 - w) * self.Hbar + w * \
92-
(self.target_accept - a * 1. / na)
93-
94-
self.step_size = np.exp(self.u - (self.m**self.k / self.gamma) * self.Hbar)
91+
self.h_bar = (1 - w) * self.h_bar + w * (self.target_accept - a * 1. / na)
92+
self.step_size = np.exp(self.u - (self.m**self.k / self.gamma) * self.h_bar)
9593
self.m += 1
9694

9795
return q
@@ -103,30 +101,33 @@ def competence(var):
103101
return Competence.INCOMPATIBLE
104102

105103

106-
def buildtree(leapfrog, q, p, u, v, j, e, Emax, E0):
107-
if j == 0:
108-
q1, p1, E = leapfrog(q, p, np.array(v * e))
109-
dE = E - E0
110-
111-
n1 = int(np.log(u) + dE <= 0)
112-
s1 = int(np.log(u) + dE < Emax)
113-
return q1, p1, q1, p1, q1, n1, s1, min(1, np.exp(-dE)), 1
114-
qn, pn, qp, pp, q1, n1, s1, a1, na1 = buildtree(leapfrog, q, p, u, v, j - 1, e, Emax, E0)
115-
if s1 == 1:
116-
if v == -1:
117-
qn, pn, _, _, q11, n11, s11, a11, na11 = buildtree(
118-
leapfrog, qn, pn, u, v, j - 1, e, Emax, E0)
119-
else:
120-
_, _, qp, pp, q11, n11, s11, a11, na11 = buildtree(
121-
leapfrog, qp, pp, u, v, j - 1, e, Emax, E0)
122-
123-
if bern(n11 * 1. / (max(n1 + n11, 1))):
124-
q1 = q11
125-
126-
a1 = a1 + a11
127-
na1 = na1 + na11
128-
129-
span = qp - qn
130-
s1 = s11 * (span.dot(pn) >= 0) * (span.dot(pp) >= 0)
131-
n1 = n1 + n11
132-
return qn, pn, qp, pp, q1, n1, s1, a1, na1
104+
def buildtree(leapfrog, q, p, u, direction, depth, step_size, Emax, start_energy):
105+
if depth == 0:
106+
q_edge, p_edge, new_energy = leapfrog(q, p, np.array(direction * step_size))
107+
energy_change = new_energy - start_energy
108+
109+
leaf_size = int(np.log(u) + energy_change <= 0)
110+
is_valid_sample = (np.log(u) + energy_change < Emax)
111+
return q_edge, p_edge, q_edge, leaf_size, is_valid_sample, min(1, np.exp(-energy_change)), 1
112+
else:
113+
depth -= 1
114+
115+
q, p, proposal, tree_size, is_valid_sample, a1, na1 = buildtree(
116+
leapfrog, q, p, u, direction, depth, step_size, Emax, start_energy)
117+
118+
if is_valid_sample:
119+
q_edge, p_edge, new_proposal, subtree_size, is_valid_subsample, a11, na11 = buildtree(
120+
leapfrog, q, p, u, direction, depth, step_size, Emax, start_energy)
121+
122+
tree_size += subtree_size
123+
if bern(subtree_size * 1. / max(tree_size, 1)):
124+
proposal = new_proposal
125+
126+
a1 += a11
127+
na1 += na11
128+
span = direction * (q_edge - q)
129+
is_valid_sample = is_valid_subsample and (span.dot(p_edge) >= 0) and (span.dot(p) >= 0)
130+
else:
131+
q_edge, p_edge = q, p
132+
133+
return q_edge, p_edge, proposal, tree_size, is_valid_sample, a1, na1

pymc3/step_methods/hmc/trajectory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def get_theano_hamiltonian_functions(model_vars, shared, logpt, potential,
116116

117117
def energy(H, q, p):
118118
"""Compute the total energy for the Hamiltonian at a given position/momentum"""
119-
return -(H.logp(q) - H.pot.energy(p))
119+
return H.pot.energy(p) - H.logp(q)
120120

121121

122122
def leapfrog(H, q, p, epsilon, n_steps):

pymc3/tests/test_distributions_timeseries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ..distributions.timeseries import EulerMaruyama
66
from ..tuning.starting import find_MAP
77
from ..sampling import sample, sample_ppc
8-
from ..step_methods.nuts import NUTS
8+
from ..step_methods import NUTS
99

1010
import numpy as np
1111
from scipy.optimize import fmin_l_bfgs_b

0 commit comments

Comments
 (0)