Skip to content

Commit a01d015

Browse files
junpenglaotwiecki
authored andcommitted
Implement robust U-turn check (#3605)
* [WIP] Robust U-turn check Following the recent discussion on the Stan side: stan-dev/stan#2800 For experiment, do not merge. * typo fix * bug fix * Additional U turn check only when depth > 1 (to avoid redundant work). * further logic to reduce redundant U Turn check. * bug fix fix error in recording the end point of the reversed subtree * [WIP] Robust U-turn check Following the recent discussion on the Stan side: stan-dev/stan#2800 For experiment, do not merge. * typo fix * bug fix * Additional U turn check only when depth > 1 (to avoid redundant work). * further logic to reduce redundant U Turn check. * bug fix fix error in recording the end point of the reversed subtree * Add release note.
1 parent 530bc41 commit a01d015

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## PyMC3 3.8 (on deck)
44

55
### New features
6+
- Implemented robust u turn check in NUTS (similar to stan-dev/stan#2800). See PR [#3605]
67
- Add capabilities to do inference on parameters in a differential equation with `DifferentialEquation`. See [#3590](https://github.com/pymc-devs/pymc3/pull/3590).
78
- Distinguish between `Data` and `Deterministic` variables when graphing models with graphviz. PR [#3491](https://github.com/pymc-devs/pymc3/pull/3491).
89
- Sequential Monte Carlo - Approximate Bayesian Computation step method is now available. The implementation is in an experimental stage and will be further improved.

pymc3/step_methods/hmc/nuts.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,10 +251,18 @@ def extend(self, direction):
251251
if direction > 0:
252252
tree, diverging, turning = self._build_subtree(
253253
self.right, self.depth, floatX(np.asarray(self.step_size)))
254+
leftmost_begin, leftmost_end = self.left, self.right
255+
rightmost_begin, rightmost_end = tree.left, tree.right
256+
leftmost_p_sum = self.p_sum
257+
rightmost_p_sum = tree.p_sum
254258
self.right = tree.right
255259
else:
256260
tree, diverging, turning = self._build_subtree(
257261
self.left, self.depth, floatX(np.asarray(-self.step_size)))
262+
leftmost_begin, leftmost_end = tree.right, tree.left
263+
rightmost_begin, rightmost_end = self.left, self.right
264+
leftmost_p_sum = tree.p_sum
265+
rightmost_p_sum = self.p_sum
258266
self.left = tree.right
259267

260268
self.depth += 1
@@ -271,9 +279,16 @@ def extend(self, direction):
271279
self.log_size = np.logaddexp(self.log_size, tree.log_size)
272280
self.p_sum[:] += tree.p_sum
273281

274-
left, right = self.left, self.right
275-
p_sum = self.p_sum
276-
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
282+
# Additional turning check only when tree depth > 0 to avoid redundant work
283+
if self.depth > 0:
284+
left, right = self.left, self.right
285+
p_sum = self.p_sum
286+
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
287+
p_sum1 = leftmost_p_sum + rightmost_begin.p
288+
turning1 = (p_sum1.dot(leftmost_begin.v) <= 0) or (p_sum1.dot(rightmost_begin.v) <= 0)
289+
p_sum2 = leftmost_end.p + rightmost_p_sum
290+
turning2 = (p_sum2.dot(leftmost_end.v) <= 0) or (p_sum2.dot(rightmost_end.v) <= 0)
291+
turning = (turning | turning1 | turning2)
277292

278293
return diverging, turning
279294

@@ -324,6 +339,13 @@ def _build_subtree(self, left, depth, epsilon):
324339
if not (diverging or turning):
325340
p_sum = tree1.p_sum + tree2.p_sum
326341
turning = (p_sum.dot(left.v) <= 0) or (p_sum.dot(right.v) <= 0)
342+
# Additional U turn check only when depth > 1 to avoid redundant work.
343+
if depth - 1 > 0:
344+
p_sum1 = tree1.p_sum + tree2.left.p
345+
turning1 = (p_sum1.dot(tree1.left.v) <= 0) or (p_sum1.dot(tree2.left.v) <= 0)
346+
p_sum2 = tree1.right.p + tree2.p_sum
347+
turning2 = (p_sum2.dot(tree1.right.v) <= 0) or (p_sum2.dot(tree2.right.v) <= 0)
348+
turning = (turning | turning1 | turning2)
327349

328350
log_size = np.logaddexp(tree1.log_size, tree2.log_size)
329351
if logbern(tree2.log_size - log_size):

0 commit comments

Comments
 (0)