@@ -66,37 +66,42 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, *
66
66
def astep (self , q0 , logp ):
67
67
q0_val = q0 .data
68
68
self .w = np .resize (self .w , len (q0_val )) # this is a repmat
69
- q = np .copy (q0_val ) # TODO: find out if we need this
69
+
70
+ q = np .copy (q0_val )
70
71
ql = np .copy (q0_val ) # l for left boundary
71
72
qr = np .copy (q0_val ) # r for right boudary
72
- for i in range (len (q0_val )):
73
+
74
+ # The points are not copied, so it's fine to update them inplace in the
75
+ # loop below
76
+ q_ra = RaveledVars (q , q0 .point_map_info )
77
+ ql_ra = RaveledVars (ql , q0 .point_map_info )
78
+ qr_ra = RaveledVars (qr , q0 .point_map_info )
79
+
80
+ for i , wi in enumerate (self .w ):
73
81
# uniformly sample from 0 to p(q), but in log space
74
- q_ra = RaveledVars (q , q0 .point_map_info )
75
82
y = logp (q_ra ) - nr .standard_exponential ()
76
83
77
84
# Create initial interval
78
- ql [i ] = q [i ] - nr .uniform () * self . w [ i ] # q[i] + r * w
79
- qr [i ] = ql [i ] + self . w [ i ] # Equivalent to q[i] + (1-r) * w
85
+ ql [i ] = q [i ] - nr .uniform () * wi # q[i] + r * w
86
+ qr [i ] = ql [i ] + wi # Equivalent to q[i] + (1-r) * w
80
87
81
88
# Stepping out procedure
82
89
cnt = 0
83
- while y <= logp (
84
- RaveledVars (ql , q0 .point_map_info )
85
- ): # changed lt to leq for locally uniform posteriors
86
- ql [i ] -= self .w [i ]
90
+ while y <= logp (ql_ra ): # changed lt to leq for locally uniform posteriors
91
+ ql [i ] -= wi
87
92
cnt += 1
88
93
if cnt > self .iter_limit :
89
94
raise RuntimeError (LOOP_ERR_MSG % self .iter_limit )
90
95
cnt = 0
91
- while y <= logp (RaveledVars ( qr , q0 . point_map_info ) ):
92
- qr [i ] += self . w [ i ]
96
+ while y <= logp (qr_ra ):
97
+ qr [i ] += wi
93
98
cnt += 1
94
99
if cnt > self .iter_limit :
95
100
raise RuntimeError (LOOP_ERR_MSG % self .iter_limit )
96
101
97
102
cnt = 0
98
103
q [i ] = nr .uniform (ql [i ], qr [i ])
99
- while logp (q_ra ) < y : # Changed leq to lt, to accomodate for locally flat posteriors
104
+ while y > logp (q_ra ): # Changed leq to lt, to accomodate for locally flat posteriors
100
105
# Sample uniformly from slice
101
106
if q [i ] > q0_val [i ]:
102
107
qr [i ] = q [i ]
@@ -110,7 +115,7 @@ def astep(self, q0, logp):
110
115
if self .tune :
111
116
# I was under impression from MacKays lectures that slice width can be tuned without
112
117
# breaking markovianness. Can we do it regardless of self.tune?(@madanh)
113
- self .w [i ] = self . w [ i ] * (self .n_tunes / (self .n_tunes + 1 )) + (qr [i ] - ql [i ]) / (
118
+ self .w [i ] = wi * (self .n_tunes / (self .n_tunes + 1 )) + (qr [i ] - ql [i ]) / (
114
119
self .n_tunes + 1
115
120
)
116
121
@@ -119,6 +124,7 @@ def astep(self, q0, logp):
119
124
120
125
if self .tune :
121
126
self .n_tunes += 1
127
+
122
128
return q
123
129
124
130
@staticmethod
0 commit comments