Skip to content

Commit f61d8cd

Browse files
authored
ENH Add burn and thin kwargs to sample. (#1562)
* ENH Add burn and thin kwargs to sample. * MAINT Move new thin and burn behind start kwarg. Make test use kwargs.
1 parent 2caa005 commit f61d8cd

File tree

2 files changed

+41
-17
lines changed

2 files changed

+41
-17
lines changed

pymc3/sampling.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropol
8282

8383

8484
def sample(draws, step=None, init='advi', n_init=200000, start=None,
85-
trace=None, chain=0, njobs=1, tune=None, progressbar=True,
86-
model=None, random_seed=-1):
85+
trace=None, thin=1, burn=0, chain=0, njobs=1, tune=None,
86+
progressbar=True, model=None, random_seed=-1):
8787
"""
8888
Draw a number of samples using the given step method.
8989
Multiple step methods supported via compound step method
@@ -120,6 +120,10 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
120120
Passing either "text" or "sqlite" is taken as a shortcut to set
121121
up the corresponding backend (with "mcmc" used as the base
122122
name).
123+
thin : int
124+
Only store every <thin>'th sample.
125+
burn : int
126+
Do not store <burn> number of first samples.
123127
chain : int
124128
Chain number used to store sample in backend. If `njobs` is
125129
greater than one, chain numbers will start here.
@@ -159,6 +163,8 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
159163
sample_args = {'draws': draws,
160164
'step': step,
161165
'start': start,
166+
'thin': thin,
167+
'burn': burn,
162168
'trace': trace,
163169
'chain': chain,
164170
'tune': tune,
@@ -175,12 +181,13 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
175181
return sample_func(**sample_args)
176182

177183

178-
def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
179-
progressbar=True, model=None, random_seed=-1):
180-
sampling = _iter_sample(draws, step, start, trace, chain,
181-
tune, model, random_seed)
184+
def _sample(draws, step=None, start=None, thin=1, burn=0, trace=None,
185+
chain=0, tune=None, progressbar=True, model=None,
186+
random_seed=-1):
187+
sampling = _iter_sample(draws, step, start, thin, burn, trace,
188+
chain, tune, model, random_seed)
182189
if progressbar:
183-
sampling = tqdm(sampling, total=draws)
190+
sampling = tqdm(sampling, total=round((draws - burn) / thin))
184191
try:
185192
for strace in sampling:
186193
pass
@@ -189,8 +196,8 @@ def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
189196
return MultiTrace([strace])
190197

191198

192-
def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
193-
model=None, random_seed=-1):
199+
def iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
200+
chain=0, tune=None, model=None, random_seed=-1):
194201
"""
195202
Generator that returns a trace on each iteration using the given
196203
step method. Multiple step methods supported via compound step
@@ -204,6 +211,10 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
204211
The number of samples to draw
205212
step : function
206213
Step function
214+
thin : int
215+
Only store every <thin>'th sample.
216+
burn : int
217+
Do not store <burn> number of first samples.
207218
start : dict
208219
Starting point in parameter space (or partial point)
209220
Defaults to trace.point(-1)) if there is a trace provided and
@@ -228,14 +239,14 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
228239
for trace in iter_sample(500, step):
229240
...
230241
"""
231-
sampling = _iter_sample(draws, step, start, trace, chain, tune,
232-
model, random_seed)
242+
sampling = _iter_sample(draws, step, start, thin, burn, trace,
243+
chain, tune, model, random_seed)
233244
for i, strace in enumerate(sampling):
234245
yield MultiTrace([strace[:i + 1]])
235246

236247

237-
def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
238-
model=None, random_seed=-1):
248+
def _iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
249+
chain=0, tune=None, model=None, random_seed=-1):
239250
model = modelcontext(model)
240251
draws = int(draws)
241252
if random_seed != -1:
@@ -265,8 +276,9 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
265276
if i == tune:
266277
step = stop_tuning(step)
267278
point = step.step(point)
268-
strace.record(point)
269-
yield strace
279+
if (i % thin == 0) and (i >= burn):
280+
strace.record(point)
281+
yield strace
270282
else:
271283
strace.close()
272284

pymc3/tests/test_sampling.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,18 @@ def test_sample(self):
5959
with self.model:
6060
for njobs in test_njobs:
6161
for steps in [1, 10, 300]:
62-
pm.sample(steps, self.step, {}, None, njobs=njobs, random_seed=self.random_seed)
62+
pm.sample(steps, self.step, njobs=njobs,
63+
random_seed=self.random_seed)
64+
65+
def test_sample_burn_thin(self):
66+
steps = 100
67+
with self.model:
68+
for burn in [0, 10, 20, 30]:
69+
for thin in [1, 5, 10, 13]:
70+
trace = pm.sample(steps, self.step, burn=burn,
71+
thin=thin,
72+
random_seed=self.random_seed)
73+
assert len(trace) == round((steps - burn) / thin)
6374

6475
def test_sample_init(self):
6576
with self.model:
@@ -71,7 +82,8 @@ def test_sample_init(self):
7182

7283
def test_iter_sample(self):
7384
with self.model:
74-
samps = pm.sampling.iter_sample(5, self.step, self.start, random_seed=self.random_seed)
85+
samps = pm.sampling.iter_sample(5, self.step,
86+
start=self.start, random_seed=self.random_seed)
7587
for i, trace in enumerate(samps):
7688
self.assertEqual(i, len(trace) - 1, "Trace does not have correct length.")
7789

0 commit comments

Comments
 (0)