Skip to content

Commit 5efe09e

Browse files
committed
Optimize NUTS
1 parent 8708f21 commit 5efe09e

File tree

1 file changed

+56
-27
lines changed

1 file changed

+56
-27
lines changed

pymc/step_methods/hmc/nuts.py

Lines changed: 56 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@
1919

2020
import numpy as np
2121

22-
from pymc.math import logbern
23-
from pymc.pytensorf import floatX
22+
from pytensor import config
23+
2424
from pymc.stats.convergence import SamplerWarning
2525
from pymc.step_methods.compound import Competence
2626
from pymc.step_methods.hmc import integration
@@ -205,11 +205,12 @@ def _hamiltonian_step(self, start, p0, step_size):
205205
else:
206206
max_treedepth = self.max_treedepth
207207

208-
tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax, rng=self.rng)
208+
rng = self.rng
209+
tree = _Tree(len(p0), self.integrator, start, step_size, self.Emax, rng=rng)
209210

210211
reached_max_treedepth = False
211212
for _ in range(max_treedepth):
212-
direction = logbern(np.log(0.5), rng=self.rng) * 2 - 1
213+
direction = (rng.random() < 0.5) * 2 - 1
213214
divergence_info, turning = tree.extend(direction)
214215

215216
if divergence_info or turning:
@@ -218,9 +219,8 @@ def _hamiltonian_step(self, start, p0, step_size):
218219
reached_max_treedepth = not self.tune
219220

220221
stats = tree.stats()
221-
accept_stat = stats["mean_tree_accept"]
222222
stats["reached_max_treedepth"] = reached_max_treedepth
223-
return HMCStepData(tree.proposal, accept_stat, divergence_info, stats)
223+
return HMCStepData(tree.proposal, stats["mean_tree_accept"], divergence_info, stats)
224224

225225
@staticmethod
226226
def competence(var, has_grad):
@@ -241,6 +241,27 @@ def competence(var, has_grad):
241241

242242

243243
class _Tree:
244+
__slots__ = (
245+
"ndim",
246+
"integrator",
247+
"start",
248+
"step_size",
249+
"Emax",
250+
"start_energy",
251+
"rng",
252+
"left",
253+
"right",
254+
"proposal",
255+
"depth",
256+
"log_size",
257+
"log_accept_sum",
258+
"mean_tree_accept",
259+
"n_proposals",
260+
"p_sum",
261+
"max_energy_change",
262+
"floatX",
263+
)
264+
244265
def __init__(
245266
self,
246267
ndim: int,
@@ -273,14 +294,15 @@ def __init__(
273294
self.rng = rng
274295

275296
self.left = self.right = start
276-
self.proposal = Proposal(start.q.data, start.q_grad, start.energy, start.model_logp, 0)
297+
self.proposal = Proposal(start.q, start.q_grad, start.energy, start.model_logp, 0)
277298
self.depth = 0
278299
self.log_size = 0.0
279300
self.log_accept_sum = -np.inf
280301
self.mean_tree_accept = 0.0
281302
self.n_proposals = 0
282303
self.p_sum = start.p.copy()
283304
self.max_energy_change = 0.0
305+
self.floatX = config.floatX
284306

285307
def extend(self, direction):
286308
"""Double the tree size by extending the tree in the given direction.
@@ -296,7 +318,7 @@ def extend(self, direction):
296318
"""
297319
if direction > 0:
298320
tree, diverging, turning = self._build_subtree(
299-
self.right, self.depth, floatX(np.asarray(self.step_size))
321+
self.right, self.depth, np.asarray(self.step_size, dtype=self.floatX)
300322
)
301323
leftmost_begin, leftmost_end = self.left, self.right
302324
rightmost_begin, rightmost_end = tree.left, tree.right
@@ -305,7 +327,7 @@ def extend(self, direction):
305327
self.right = tree.right
306328
else:
307329
tree, diverging, turning = self._build_subtree(
308-
self.left, self.depth, floatX(np.asarray(-self.step_size))
330+
self.left, self.depth, np.asarray(-self.step_size, dtype=self.floatX)
309331
)
310332
leftmost_begin, leftmost_end = tree.right, tree.left
311333
rightmost_begin, rightmost_end = self.left, self.right
@@ -318,23 +340,27 @@ def extend(self, direction):
318340
if diverging or turning:
319341
return diverging, turning
320342

321-
size1, size2 = self.log_size, tree.log_size
322-
if logbern(size2 - size1, rng=self.rng):
343+
self_log_size, tree_log_size = self.log_size, tree.log_size
344+
if np.log(self.rng.random()) < (tree_log_size - self_log_size):
323345
self.proposal = tree.proposal
324346

325-
self.log_size = np.logaddexp(self.log_size, tree.log_size)
326-
self.p_sum[:] += tree.p_sum
347+
self.log_size = np.logaddexp(tree_log_size, self_log_size)
348+
349+
p_sum = self.p_sum
350+
p_sum[:] += tree.p_sum
327351

328352
# Additional turning check only when tree depth > 0 to avoid redundant work
329353
if self.depth > 0:
330354
left, right = self.left, self.right
331-
p_sum = self.p_sum
332355
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
333-
p_sum1 = leftmost_p_sum + rightmost_begin.p
334-
turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
335-
p_sum2 = leftmost_end.p + rightmost_p_sum
336-
turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
337-
turning = turning | turning1 | turning2
356+
if not turning:
357+
p_sum1 = leftmost_p_sum + rightmost_begin.p
358+
turning = (p_sum1.dot(leftmost_begin.v) <= 0) or (
359+
p_sum1.dot(rightmost_begin.v) <= 0
360+
)
361+
if not turning:
362+
p_sum2 = leftmost_end.p + rightmost_p_sum
363+
turning = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
338364

339365
return diverging, turning
340366

@@ -356,7 +382,10 @@ def _single_step(self, left: State, epsilon: float):
356382
if np.isnan(energy_change):
357383
energy_change = np.inf
358384

359-
self.log_accept_sum = np.logaddexp(self.log_accept_sum, min(0, -energy_change))
385+
self.log_accept_sum = np.logaddexp(
386+
self.log_accept_sum, (-energy_change if energy_change > 0 else 0)
387+
)
388+
# The conditional expression above is equivalent to the previous min(0, -energy_change).
360389

361390
if np.abs(energy_change) > np.abs(self.max_energy_change):
362391
self.max_energy_change = energy_change
@@ -366,7 +395,7 @@ def _single_step(self, left: State, epsilon: float):
366395
# Saturated Metropolis accept probability with Boltzmann weight
367396
log_size = -energy_change
368397
proposal = Proposal(
369-
right.q.data,
398+
right.q,
370399
right.q_grad,
371400
right.energy,
372401
right.model_logp,
@@ -399,15 +428,15 @@ def _build_subtree(self, left, depth, epsilon):
399428
p_sum = tree1.p_sum + tree2.p_sum
400429
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
401430
# Additional U-turn checks only when depth > 1, to avoid redundant work.
402-
if depth - 1 > 0:
431+
if (not turning) and (depth - 1 > 0):
403432
p_sum1 = tree1.p_sum + tree2.left.p
404-
turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
405-
p_sum2 = tree1.right.p + tree2.p_sum
406-
turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
407-
turning = turning | turning1 | turning2
433+
turning = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
434+
if not turning:
435+
p_sum2 = tree1.right.p + tree2.p_sum
436+
turning = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
408437

409438
log_size = np.logaddexp(tree1.log_size, tree2.log_size)
410-
if logbern(tree2.log_size - log_size, rng=self.rng):
439+
if np.log(self.rng.random()) < (tree2.log_size - log_size):
411440
proposal = tree2.proposal
412441
else:
413442
proposal = tree1.proposal

0 commit comments

Comments
 (0)