Skip to content

Commit 577c46b

Browse files
committed
Add sample_stats to Slice sampler
1 parent dd55284 commit 577c46b

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

pymc/step_methods/slicer.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ class Slice(ArrayStep):
4747

4848
name = "slice"
4949
default_blocked = False
50+
generates_stats = True
51+
stats_dtypes = [
52+
{
53+
"nstep_out": int,
54+
"nstep_in": int,
55+
}
56+
]
5057

5158
def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, **kwargs):
5259
self.model = modelcontext(model)
@@ -67,6 +74,8 @@ def astep(self, q0, logp):
6774
q0_val = q0.data
6875
self.w = np.resize(self.w, len(q0_val)) # this is a repmat
6976

77+
nstep_out = nstep_in = 0
78+
7079
q = np.copy(q0_val)
7180
ql = np.copy(q0_val) # l for left boundary
7281
qr = np.copy(q0_val) # r for right boudary
@@ -92,12 +101,15 @@ def astep(self, q0, logp):
92101
cnt += 1
93102
if cnt > self.iter_limit:
94103
raise RuntimeError(LOOP_ERR_MSG % self.iter_limit)
104+
nstep_out += cnt
105+
95106
cnt = 0
96107
while y <= logp(qr_ra):
97108
qr[i] += wi
98109
cnt += 1
99110
if cnt > self.iter_limit:
100111
raise RuntimeError(LOOP_ERR_MSG % self.iter_limit)
112+
nstep_out += cnt
101113

102114
cnt = 0
103115
q[i] = nr.uniform(ql[i], qr[i])
@@ -111,6 +123,7 @@ def astep(self, q0, logp):
111123
cnt += 1
112124
if cnt > self.iter_limit:
113125
raise RuntimeError(LOOP_ERR_MSG % self.iter_limit)
126+
nstep_in += cnt
114127

115128
if self.tune:
116129
# I was under impression from MacKays lectures that slice width can be tuned without
@@ -125,7 +138,12 @@ def astep(self, q0, logp):
125138
if self.tune:
126139
self.n_tunes += 1
127140

128-
return q
141+
stats = {
142+
"nstep_out": nstep_out,
143+
"nstep_in": nstep_in,
144+
}
145+
146+
return q, (stats,)
129147

130148
@staticmethod
131149
def competence(var, has_grad):

0 commit comments

Comments
 (0)