@@ -82,8 +82,8 @@ def assign_step_methods(model, step=None, methods=(NUTS, HamiltonianMC, Metropol
82
82
83
83
84
84
def sample (draws , step = None , init = 'advi' , n_init = 200000 , start = None ,
85
- trace = None , chain = 0 , njobs = 1 , tune = None , progressbar = True ,
86
- model = None , random_seed = - 1 ):
85
+ trace = None , thin = 1 , burn = 0 , chain = 0 , njobs = 1 , tune = None ,
86
+ progressbar = True , model = None , random_seed = - 1 ):
87
87
"""
88
88
Draw a number of samples using the given step method.
89
89
Multiple step methods supported via compound step method
@@ -120,6 +120,10 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
120
120
Passing either "text" or "sqlite" is taken as a shortcut to set
121
121
up the corresponding backend (with "mcmc" used as the base
122
122
name).
123
+ thin : int
124
+ Only store every <thin>'th sample.
125
+ burn : int
126
+ Do not store <burn> number of first samples.
123
127
chain : int
124
128
Chain number used to store sample in backend. If `njobs` is
125
129
greater than one, chain numbers will start here.
@@ -159,6 +163,8 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
159
163
sample_args = {'draws' : draws ,
160
164
'step' : step ,
161
165
'start' : start ,
166
+ 'thin' : thin ,
167
+ 'burn' : burn ,
162
168
'trace' : trace ,
163
169
'chain' : chain ,
164
170
'tune' : tune ,
@@ -175,12 +181,13 @@ def sample(draws, step=None, init='advi', n_init=200000, start=None,
175
181
return sample_func (** sample_args )
176
182
177
183
178
- def _sample (draws , step = None , start = None , trace = None , chain = 0 , tune = None ,
179
- progressbar = True , model = None , random_seed = - 1 ):
180
- sampling = _iter_sample (draws , step , start , trace , chain ,
181
- tune , model , random_seed )
184
+ def _sample (draws , step = None , start = None , thin = 1 , burn = 0 , trace = None ,
185
+ chain = 0 , tune = None , progressbar = True , model = None ,
186
+ random_seed = - 1 ):
187
+ sampling = _iter_sample (draws , step , start , thin , burn , trace ,
188
+ chain , tune , model , random_seed )
182
189
if progressbar :
183
- sampling = tqdm (sampling , total = draws )
190
+ sampling = tqdm (sampling , total = round (( draws - burn ) / thin ) )
184
191
try :
185
192
for strace in sampling :
186
193
pass
@@ -189,8 +196,8 @@ def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
189
196
return MultiTrace ([strace ])
190
197
191
198
192
- def iter_sample (draws , step , start = None , trace = None , chain = 0 , tune = None ,
193
- model = None , random_seed = - 1 ):
199
+ def iter_sample (draws , step , start = None , thin = 1 , burn = 0 , trace = None ,
200
+ chain = 0 , tune = None , model = None , random_seed = - 1 ):
194
201
"""
195
202
Generator that returns a trace on each iteration using the given
196
203
step method. Multiple step methods supported via compound step
@@ -204,6 +211,10 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
204
211
The number of samples to draw
205
212
step : function
206
213
Step function
214
+ thin : int
215
+ Only store every <thin>'th sample.
216
+ burn : int
217
+ Do not store <burn> number of first samples.
207
218
start : dict
208
219
Starting point in parameter space (or partial point)
209
220
Defaults to trace.point(-1)) if there is a trace provided and
@@ -228,14 +239,14 @@ def iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
228
239
for trace in iter_sample(500, step):
229
240
...
230
241
"""
231
- sampling = _iter_sample (draws , step , start , trace , chain , tune ,
232
- model , random_seed )
242
+ sampling = _iter_sample (draws , step , start , thin , burn , trace ,
243
+ chain , tune , model , random_seed )
233
244
for i , strace in enumerate (sampling ):
234
245
yield MultiTrace ([strace [:i + 1 ]])
235
246
236
247
237
- def _iter_sample (draws , step , start = None , trace = None , chain = 0 , tune = None ,
238
- model = None , random_seed = - 1 ):
248
+ def _iter_sample (draws , step , start = None , thin = 1 , burn = 0 , trace = None ,
249
+ chain = 0 , tune = None , model = None , random_seed = - 1 ):
239
250
model = modelcontext (model )
240
251
draws = int (draws )
241
252
if random_seed != - 1 :
@@ -265,8 +276,9 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
265
276
if i == tune :
266
277
step = stop_tuning (step )
267
278
point = step .step (point )
268
- strace .record (point )
269
- yield strace
279
+ if (i % thin == 0 ) and (i >= burn ):
280
+ strace .record (point )
281
+ yield strace
270
282
else :
271
283
strace .close ()
272
284
0 commit comments