Skip to content

Commit 8a6d87e

Browse files
authored
Revert "ENH Add burn and thin kwargs to sample." (#1564)
1 parent 1c9adc6 commit 8a6d87e

File tree

2 files changed

+17
-41
lines changed

2 files changed

+17
-41
lines changed

pymc3/sampling.py

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropol
8282

8383

8484
def sample(draws, step=None, init='advi', n_init=200000, start=None,
85-
trace=None, thin=1, burn=0, chain=0, njobs=1, tune=None,
86-
progressbar=True, model=None, random_seed=-1):
85+
trace=None, chain=0, njobs=1, tune=None, progressbar=True,
86+
model=None, random_seed=-1):
8787
"""
8888
Draw a number of samples using the given step method.
8989
Multiple step methods supported via compound step method
@@ -120,10 +120,6 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
120120
Passing either "text" or "sqlite" is taken as a shortcut to set
121121
up the corresponding backend (with "mcmc" used as the base
122122
name).
123-
thin : int
124-
Only store every <thin>'th sample.
125-
burn : int
126-
Do not store <burn> number of first samples.
127123
chain : int
128124
Chain number used to store sample in backend. If `njobs` is
129125
greater than one, chain numbers will start here.
@@ -163,8 +159,6 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
163159
sample_args = {'draws': draws,
164160
'step': step,
165161
'start': start,
166-
'thin': thin,
167-
'burn': burn,
168162
'trace': trace,
169163
'chain': chain,
170164
'tune': tune,
@@ -181,13 +175,12 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
181175
return sample_func(**sample_args)
182176

183177

184-
def _sample(draws, step=None, start=None, thin=1, burn=0, trace=None,
185-
chain=0, tune=None, progressbar=True, model=None,
186-
random_seed=-1):
187-
sampling = _iter_sample(draws, step, start, thin, burn, trace,
188-
chain, tune, model, random_seed)
178+
def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
179+
progressbar=True, model=None, random_seed=-1):
180+
sampling = _iter_sample(draws, step, start, trace, chain,
181+
tune, model, random_seed)
189182
if progressbar:
190-
sampling = tqdm(sampling, total=round((draws - burn) / thin))
183+
sampling = tqdm(sampling, total=draws)
191184
try:
192185
for strace in sampling:
193186
pass
@@ -196,8 +189,8 @@ def _sample(draws, step=None, start=None, thin=1, burn=0, trace=None,
196189
return MultiTrace([strace])
197190

198191

199-
def iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
200-
chain=0, tune=None, model=None, random_seed=-1):
192+
def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
193+
model=None, random_seed=-1):
201194
"""
202195
Generator that returns a trace on each iteration using the given
203196
step method. Multiple step methods supported via compound step
@@ -211,10 +204,6 @@ def iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
211204
The number of samples to draw
212205
step : function
213206
Step function
214-
thin : int
215-
Only store every <thin>'th sample.
216-
burn : int
217-
Do not store <burn> number of first samples.
218207
start : dict
219208
Starting point in parameter space (or partial point)
220209
Defaults to trace.point(-1)) if there is a trace provided and
@@ -239,14 +228,14 @@ def iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
239228
for trace in iter_sample(500, step):
240229
...
241230
"""
242-
sampling = _iter_sample(draws, step, start, thin, burn, trace,
243-
chain, tune, model, random_seed)
231+
sampling = _iter_sample(draws, step, start, trace, chain, tune,
232+
model, random_seed)
244233
for i, strace in enumerate(sampling):
245234
yield MultiTrace([strace[:i + 1]])
246235

247236

248-
def _iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
249-
chain=0, tune=None, model=None, random_seed=-1):
237+
def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
238+
model=None, random_seed=-1):
250239
model = modelcontext(model)
251240
draws = int(draws)
252241
if random_seed != -1:
@@ -276,9 +265,8 @@ def _iter_sample(draws, step, start=None, thin=1, burn=0, trace=None,
276265
if i == tune:
277266
step = stop_tuning(step)
278267
point = step.step(point)
279-
if (i % thin == 0) and (i >= burn):
280-
strace.record(point)
281-
yield strace
268+
strace.record(point)
269+
yield strace
282270
else:
283271
strace.close()
284272

pymc3/tests/test_sampling.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,7 @@ def test_sample(self):
5959
with self.model:
6060
for njobs in test_njobs:
6161
for steps in [1, 10, 300]:
62-
pm.sample(steps, self.step, njobs=njobs,
63-
random_seed=self.random_seed)
64-
65-
def test_sample_burn_thin(self):
66-
steps = 100
67-
with self.model:
68-
for burn in [0, 10, 20, 30]:
69-
for thin in [1, 5, 10, 13]:
70-
trace = pm.sample(steps, self.step, burn=burn,
71-
thin=thin,
72-
random_seed=self.random_seed)
73-
assert len(trace) == round((steps - burn) / thin)
62+
pm.sample(steps, self.step, {}, None, njobs=njobs, random_seed=self.random_seed)
7463

7564
def test_sample_init(self):
7665
with self.model:
@@ -82,8 +71,7 @@ def test_sample_init(self):
8271

8372
def test_iter_sample(self):
8473
with self.model:
85-
samps = pm.sampling.iter_sample(5, self.step,
86-
start=self.start, random_seed=self.random_seed)
74+
samps = pm.sampling.iter_sample(5, self.step, self.start, random_seed=self.random_seed)
8775
for i, trace in enumerate(samps):
8876
self.assertEqual(i, len(trace) - 1, "Trace does not have correct length.")
8977

0 commit comments

Comments
 (0)