@@ -49,49 +49,47 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
49
49
self .k = k
50
50
self .t0 = t0
51
51
52
- self .Hbar = 0
52
+ self .h_bar = 0
53
53
self .u = np .log (self .step_size * 10 )
54
54
self .m = 1
55
55
56
56
def astep (self , q0 ):
57
- Emax = self .Emax
58
- e = self .step_size
59
-
60
57
p0 = self .potential .random ()
61
- E0 = self .compute_energy (q0 , p0 )
58
+ start_energy = self .compute_energy (q0 , p0 )
62
59
63
60
u = nr .uniform ()
64
61
q = qn = qp = q0
65
- p = pn = pp = p0
62
+ pn = pp = p0
63
+
64
+ tree_size , depth = 1. , 0
65
+ keep_sampling = True
66
66
67
- n , s , j = 1 , 1 , 0
67
+ while keep_sampling :
68
+ direction = bern (0.5 ) * 2 - 1
69
+ q_edge , p_edge = {- 1 : (qn , pn ), 1 : (qp , pp )}[direction ]
68
70
69
- while s == 1 :
70
- v = bern (0.5 ) * 2 - 1
71
+ q_edge , p_edge , proposal , subtree_size , is_valid_sample , a , na = buildtree (
72
+ self .leapfrog , q_edge , p_edge ,
73
+ u , direction , depth ,
74
+ self .step_size , self .Emax , start_energy )
71
75
72
- if v == - 1 :
73
- qn , pn , _ , _ , q1 , n1 , s1 , a , na = buildtree (
74
- self .leapfrog , qn , pn , u , v , j , e , Emax , E0 )
76
+ if direction == - 1 :
77
+ qn , pn = q_edge , p_edge
75
78
else :
76
- _ , _ , qp , pp , q1 , n1 , s1 , a , na = buildtree (
77
- self .leapfrog , qp , pp , u , v , j , e , Emax , E0 )
79
+ qp , pp = q_edge , p_edge
78
80
79
- if s1 == 1 and bern (min (1 , n1 * 1. / n )):
80
- q = q1
81
+ if is_valid_sample and bern (min (1 , subtree_size / tree_size )):
82
+ q = proposal
81
83
82
- n = n + n1
84
+ tree_size += subtree_size
83
85
84
86
span = qp - qn
85
- s = s1 * (span .dot (pn ) >= 0 ) * (span .dot (pp ) >= 0 )
86
- j = j + 1
87
-
88
- p = - p
87
+ keep_sampling = is_valid_sample and (span .dot (pn ) >= 0 ) and (span .dot (pp ) >= 0 )
88
+ depth += 1
89
89
90
90
w = 1. / (self .m + self .t0 )
91
- self .Hbar = (1 - w ) * self .Hbar + w * \
92
- (self .target_accept - a * 1. / na )
93
-
94
- self .step_size = np .exp (self .u - (self .m ** self .k / self .gamma ) * self .Hbar )
91
+ self .h_bar = (1 - w ) * self .h_bar + w * (self .target_accept - a * 1. / na )
92
+ self .step_size = np .exp (self .u - (self .m ** self .k / self .gamma ) * self .h_bar )
95
93
self .m += 1
96
94
97
95
return q
@@ -103,30 +101,33 @@ def competence(var):
103
101
return Competence .INCOMPATIBLE
104
102
105
103
106
- def buildtree (leapfrog , q , p , u , v , j , e , Emax , E0 ):
107
- if j == 0 :
108
- q1 , p1 , E = leapfrog (q , p , np .array (v * e ))
109
- dE = E - E0
110
-
111
- n1 = int (np .log (u ) + dE <= 0 )
112
- s1 = int (np .log (u ) + dE < Emax )
113
- return q1 , p1 , q1 , p1 , q1 , n1 , s1 , min (1 , np .exp (- dE )), 1
114
- qn , pn , qp , pp , q1 , n1 , s1 , a1 , na1 = buildtree (leapfrog , q , p , u , v , j - 1 , e , Emax , E0 )
115
- if s1 == 1 :
116
- if v == - 1 :
117
- qn , pn , _ , _ , q11 , n11 , s11 , a11 , na11 = buildtree (
118
- leapfrog , qn , pn , u , v , j - 1 , e , Emax , E0 )
119
- else :
120
- _ , _ , qp , pp , q11 , n11 , s11 , a11 , na11 = buildtree (
121
- leapfrog , qp , pp , u , v , j - 1 , e , Emax , E0 )
122
-
123
- if bern (n11 * 1. / (max (n1 + n11 , 1 ))):
124
- q1 = q11
125
-
126
- a1 = a1 + a11
127
- na1 = na1 + na11
128
-
129
- span = qp - qn
130
- s1 = s11 * (span .dot (pn ) >= 0 ) * (span .dot (pp ) >= 0 )
131
- n1 = n1 + n11
132
- return qn , pn , qp , pp , q1 , n1 , s1 , a1 , na1
104
+ def buildtree (leapfrog , q , p , u , direction , depth , step_size , Emax , start_energy ):
105
+ if depth == 0 :
106
+ q_edge , p_edge , new_energy = leapfrog (q , p , np .array (direction * step_size ))
107
+ energy_change = new_energy - start_energy
108
+
109
+ leaf_size = int (np .log (u ) + energy_change <= 0 )
110
+ is_valid_sample = (np .log (u ) + energy_change < Emax )
111
+ return q_edge , p_edge , q_edge , leaf_size , is_valid_sample , min (1 , np .exp (- energy_change )), 1
112
+ else :
113
+ depth -= 1
114
+
115
+ q , p , proposal , tree_size , is_valid_sample , a1 , na1 = buildtree (
116
+ leapfrog , q , p , u , direction , depth , step_size , Emax , start_energy )
117
+
118
+ if is_valid_sample :
119
+ q_edge , p_edge , new_proposal , subtree_size , is_valid_subsample , a11 , na11 = buildtree (
120
+ leapfrog , q , p , u , direction , depth , step_size , Emax , start_energy )
121
+
122
+ tree_size += subtree_size
123
+ if bern (subtree_size * 1. / max (tree_size , 1 )):
124
+ proposal = new_proposal
125
+
126
+ a1 += a11
127
+ na1 += na11
128
+ span = direction * (q_edge - q )
129
+ is_valid_sample = is_valid_subsample and (span .dot (p_edge ) >= 0 ) and (span .dot (p ) >= 0 )
130
+ else :
131
+ q_edge , p_edge = q , p
132
+
133
+ return q_edge , p_edge , proposal , tree_size , is_valid_sample , a1 , na1
0 commit comments