19
19
20
20
import numpy as np
21
21
22
- from pymc . math import logbern
23
- from pymc . pytensorf import floatX
22
+ from pytensor import config
23
+
24
24
from pymc .stats .convergence import SamplerWarning
25
25
from pymc .step_methods .compound import Competence
26
26
from pymc .step_methods .hmc import integration
@@ -205,11 +205,12 @@ def _hamiltonian_step(self, start, p0, step_size):
205
205
else :
206
206
max_treedepth = self .max_treedepth
207
207
208
- tree = _Tree (len (p0 ), self .integrator , start , step_size , self .Emax , rng = self .rng )
208
+ rng = self .rng
209
+ tree = _Tree (len (p0 ), self .integrator , start , step_size , self .Emax , rng = rng )
209
210
210
211
reached_max_treedepth = False
211
212
for _ in range (max_treedepth ):
212
- direction = logbern ( np . log ( 0.5 ), rng = self . rng ) * 2 - 1
213
+ direction = ( rng . random () < 0.5 ) * 2 - 1
213
214
divergence_info , turning = tree .extend (direction )
214
215
215
216
if divergence_info or turning :
@@ -218,9 +219,8 @@ def _hamiltonian_step(self, start, p0, step_size):
218
219
reached_max_treedepth = not self .tune
219
220
220
221
stats = tree .stats ()
221
- accept_stat = stats ["mean_tree_accept" ]
222
222
stats ["reached_max_treedepth" ] = reached_max_treedepth
223
- return HMCStepData (tree .proposal , accept_stat , divergence_info , stats )
223
+ return HMCStepData (tree .proposal , stats [ "mean_tree_accept" ] , divergence_info , stats )
224
224
225
225
@staticmethod
226
226
def competence (var , has_grad ):
@@ -241,6 +241,27 @@ def competence(var, has_grad):
241
241
242
242
243
243
class _Tree :
244
+ __slots__ = (
245
+ "ndim" ,
246
+ "integrator" ,
247
+ "start" ,
248
+ "step_size" ,
249
+ "Emax" ,
250
+ "start_energy" ,
251
+ "rng" ,
252
+ "left" ,
253
+ "right" ,
254
+ "proposal" ,
255
+ "depth" ,
256
+ "log_size" ,
257
+ "log_accept_sum" ,
258
+ "mean_tree_accept" ,
259
+ "n_proposals" ,
260
+ "p_sum" ,
261
+ "max_energy_change" ,
262
+ "floatX" ,
263
+ )
264
+
244
265
def __init__ (
245
266
self ,
246
267
ndim : int ,
@@ -273,14 +294,15 @@ def __init__(
273
294
self .rng = rng
274
295
275
296
self .left = self .right = start
276
- self .proposal = Proposal (start .q . data , start .q_grad , start .energy , start .model_logp , 0 )
297
+ self .proposal = Proposal (start .q , start .q_grad , start .energy , start .model_logp , 0 )
277
298
self .depth = 0
278
299
self .log_size = 0.0
279
300
self .log_accept_sum = - np .inf
280
301
self .mean_tree_accept = 0.0
281
302
self .n_proposals = 0
282
303
self .p_sum = start .p .copy ()
283
304
self .max_energy_change = 0.0
305
+ self .floatX = config .floatX
284
306
285
307
def extend (self , direction ):
286
308
"""Double the treesize by extending the tree in the given direction.
@@ -296,7 +318,7 @@ def extend(self, direction):
296
318
"""
297
319
if direction > 0 :
298
320
tree , diverging , turning = self ._build_subtree (
299
- self .right , self .depth , floatX ( np .asarray (self .step_size ) )
321
+ self .right , self .depth , np .asarray (self .step_size , dtype = self . floatX )
300
322
)
301
323
leftmost_begin , leftmost_end = self .left , self .right
302
324
rightmost_begin , rightmost_end = tree .left , tree .right
@@ -305,7 +327,7 @@ def extend(self, direction):
305
327
self .right = tree .right
306
328
else :
307
329
tree , diverging , turning = self ._build_subtree (
308
- self .left , self .depth , floatX ( np .asarray (- self .step_size ) )
330
+ self .left , self .depth , np .asarray (- self .step_size , dtype = self . floatX )
309
331
)
310
332
leftmost_begin , leftmost_end = tree .right , tree .left
311
333
rightmost_begin , rightmost_end = self .left , self .right
@@ -318,23 +340,27 @@ def extend(self, direction):
318
340
if diverging or turning :
319
341
return diverging , turning
320
342
321
- size1 , size2 = self .log_size , tree .log_size
322
- if logbern ( size2 - size1 , rng = self .rng ):
343
+ self_log_size , tree_log_size = self .log_size , tree .log_size
344
+ if np . log ( self .rng . random ()) < ( tree_log_size - self_log_size ):
323
345
self .proposal = tree .proposal
324
346
325
- self .log_size = np .logaddexp (self .log_size , tree .log_size )
326
- self .p_sum [:] += tree .p_sum
347
+ self .log_size = np .logaddexp (tree_log_size , self_log_size )
348
+
349
+ p_sum = self .p_sum
350
+ p_sum [:] += tree .p_sum
327
351
328
352
# Additional turning check only when tree depth > 0 to avoid redundant work
329
353
if self .depth > 0 :
330
354
left , right = self .left , self .right
331
- p_sum = self .p_sum
332
355
turning = (p_sum .dot (left .v ) <= 0 ) or (p_sum .dot (right .v ) <= 0 )
333
- p_sum1 = leftmost_p_sum + rightmost_begin .p
334
- turning1 = (p_sum1 .dot (leftmost_begin .v ) <= 0 ) or (p_sum1 .dot (rightmost_begin .v ) <= 0 )
335
- p_sum2 = leftmost_end .p + rightmost_p_sum
336
- turning2 = (p_sum2 .dot (leftmost_end .v ) <= 0 ) or (p_sum2 .dot (rightmost_end .v ) <= 0 )
337
- turning = turning | turning1 | turning2
356
+ if not turning :
357
+ p_sum1 = leftmost_p_sum + rightmost_begin .p
358
+ turning = (p_sum1 .dot (leftmost_begin .v ) <= 0 ) or (
359
+ p_sum1 .dot (rightmost_begin .v ) <= 0
360
+ )
361
+ if not turning :
362
+ p_sum2 = leftmost_end .p + rightmost_p_sum
363
+ turning = (p_sum2 .dot (leftmost_end .v ) <= 0 ) or (p_sum2 .dot (rightmost_end .v ) <= 0 )
338
364
339
365
return diverging , turning
340
366
@@ -356,7 +382,10 @@ def _single_step(self, left: State, epsilon: float):
356
382
if np .isnan (energy_change ):
357
383
energy_change = np .inf
358
384
359
- self .log_accept_sum = np .logaddexp (self .log_accept_sum , min (0 , - energy_change ))
385
+ self .log_accept_sum = np .logaddexp (
386
+ self .log_accept_sum , (- energy_change if energy_change > 0 else 0 )
387
+ )
388
+ # self.log_accept_sum = np.logaddexp(self.log_accept_sum, min(0, -energy_change))
360
389
361
390
if np .abs (energy_change ) > np .abs (self .max_energy_change ):
362
391
self .max_energy_change = energy_change
@@ -366,7 +395,7 @@ def _single_step(self, left: State, epsilon: float):
366
395
# Saturated Metropolis accept probability with Boltzmann weight
367
396
log_size = - energy_change
368
397
proposal = Proposal (
369
- right .q . data ,
398
+ right .q ,
370
399
right .q_grad ,
371
400
right .energy ,
372
401
right .model_logp ,
@@ -399,15 +428,15 @@ def _build_subtree(self, left, depth, epsilon):
399
428
p_sum = tree1 .p_sum + tree2 .p_sum
400
429
turning = (p_sum .dot (left .v ) <= 0 ) or (p_sum .dot (right .v ) <= 0 )
401
430
# Additional U turn check only when depth > 1 to avoid redundant work.
402
- if depth - 1 > 0 :
431
+ if ( not turning ) and ( depth - 1 > 0 ) :
403
432
p_sum1 = tree1 .p_sum + tree2 .left .p
404
- turning1 = (p_sum1 .dot (tree1 .left .v ) <= 0 ) or (p_sum1 .dot (tree2 .left .v ) <= 0 )
405
- p_sum2 = tree1 . right . p + tree2 . p_sum
406
- turning2 = ( p_sum2 . dot ( tree1 . right . v ) <= 0 ) or ( p_sum2 . dot ( tree2 . right .v ) <= 0 )
407
- turning = turning | turning1 | turning2
433
+ turning = (p_sum1 .dot (tree1 .left .v ) <= 0 ) or (p_sum1 .dot (tree2 .left .v ) <= 0 )
434
+ if not turning :
435
+ p_sum2 = tree1 . right .p + tree2 . p_sum
436
+ turning = ( p_sum2 . dot ( tree1 . right . v ) <= 0 ) or ( p_sum2 . dot ( tree2 . right . v ) <= 0 )
408
437
409
438
log_size = np .logaddexp (tree1 .log_size , tree2 .log_size )
410
- if logbern ( tree2 .log_size - log_size , rng = self . rng ):
439
+ if np . log ( self . rng . random ()) < ( tree2 .log_size - log_size ):
411
440
proposal = tree2 .proposal
412
441
else :
413
442
proposal = tree1 .proposal
0 commit comments