Skip to content

Commit 70a3e0b

Browse files
ricardoV94michaelosthege
authored andcommitted
Refactor Slice.astep method
Avoid recreation of RaveledArrays in inner loops and repeated indexing of `self.w`
1 parent f28f9cf commit 70a3e0b

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

pymc/step_methods/slicer.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,37 +66,42 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, *
6666
def astep(self, q0, logp):
6767
q0_val = q0.data
6868
self.w = np.resize(self.w, len(q0_val)) # this is a repmat
69-
q = np.copy(q0_val) # TODO: find out if we need this
69+
70+
q = np.copy(q0_val)
7071
ql = np.copy(q0_val) # l for left boundary
7172
qr = np.copy(q0_val) # r for right boudary
72-
for i in range(len(q0_val)):
73+
74+
# The points are not copied, so it's fine to update them inplace in the
75+
# loop below
76+
q_ra = RaveledVars(q, q0.point_map_info)
77+
ql_ra = RaveledVars(ql, q0.point_map_info)
78+
qr_ra = RaveledVars(qr, q0.point_map_info)
79+
80+
for i, wi in enumerate(self.w):
7381
# uniformly sample from 0 to p(q), but in log space
74-
q_ra = RaveledVars(q, q0.point_map_info)
7582
y = logp(q_ra) - nr.standard_exponential()
7683

7784
# Create initial interval
78-
ql[i] = q[i] - nr.uniform() * self.w[i] # q[i] + r * w
79-
qr[i] = ql[i] + self.w[i] # Equivalent to q[i] + (1-r) * w
85+
ql[i] = q[i] - nr.uniform() * wi # q[i] + r * w
86+
qr[i] = ql[i] + wi # Equivalent to q[i] + (1-r) * w
8087

8188
# Stepping out procedure
8289
cnt = 0
83-
while y <= logp(
84-
RaveledVars(ql, q0.point_map_info)
85-
): # changed lt to leq for locally uniform posteriors
86-
ql[i] -= self.w[i]
90+
while y <= logp(ql_ra): # changed lt to leq for locally uniform posteriors
91+
ql[i] -= wi
8792
cnt += 1
8893
if cnt > self.iter_limit:
8994
raise RuntimeError(LOOP_ERR_MSG % self.iter_limit)
9095
cnt = 0
91-
while y <= logp(RaveledVars(qr, q0.point_map_info)):
92-
qr[i] += self.w[i]
96+
while y <= logp(qr_ra):
97+
qr[i] += wi
9398
cnt += 1
9499
if cnt > self.iter_limit:
95100
raise RuntimeError(LOOP_ERR_MSG % self.iter_limit)
96101

97102
cnt = 0
98103
q[i] = nr.uniform(ql[i], qr[i])
99-
while logp(q_ra) < y: # Changed leq to lt, to accomodate for locally flat posteriors
104+
while y > logp(q_ra): # Changed leq to lt, to accomodate for locally flat posteriors
100105
# Sample uniformly from slice
101106
if q[i] > q0_val[i]:
102107
qr[i] = q[i]
@@ -110,7 +115,7 @@ def astep(self, q0, logp):
110115
if self.tune:
111116
# I was under impression from MacKays lectures that slice width can be tuned without
112117
# breaking markovianness. Can we do it regardless of self.tune?(@madanh)
113-
self.w[i] = self.w[i] * (self.n_tunes / (self.n_tunes + 1)) + (qr[i] - ql[i]) / (
118+
self.w[i] = wi * (self.n_tunes / (self.n_tunes + 1)) + (qr[i] - ql[i]) / (
114119
self.n_tunes + 1
115120
)
116121

@@ -119,6 +124,7 @@ def astep(self, q0, logp):
119124

120125
if self.tune:
121126
self.n_tunes += 1
127+
122128
return q
123129

124130
@staticmethod

0 commit comments

Comments
 (0)