Skip to content

Commit de83381

Browse files
Add progressbar to sample_smc and deprecate parallel (#4826)
* Add progressbar to `sample_smc` and deprecate `parallel` Co-authored-by: Ricardo Vieira <[email protected]> Co-authored-by: Michael Osthege <[email protected]>
1 parent 13487d0 commit de83381

File tree

2 files changed

+93
-25
lines changed

2 files changed

+93
-25
lines changed

pymc3/smc/sample_smc.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323

2424
from arviz import InferenceData
25+
from fastprogress.fastprogress import progress_bar
2526

2627
import pymc3
2728

@@ -45,12 +46,13 @@ def sample_smc(
4546
save_log_pseudolikelihood=True,
4647
model=None,
4748
random_seed=-1,
48-
parallel=False,
49+
parallel=None,
4950
chains=None,
5051
cores=None,
5152
compute_convergence_checks=True,
5253
return_inferencedata=True,
5354
idata_kwargs=None,
55+
progressbar=True,
5456
):
5557
r"""
5658
Sequential Monte Carlo based sampling.
@@ -90,12 +92,9 @@ def sample_smc(
9092
model: Model (optional if in ``with`` context)).
9193
random_seed: int
9294
random seed
93-
parallel: bool
94-
Distribute computations across cores if the number of cores is larger than 1.
95-
Defaults to False.
9695
cores : int
9796
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
98-
system, but at most 4.
97+
system.
9998
chains : int
10099
The number of chains to sample. Running independent chains is important for some
101100
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
@@ -108,6 +107,9 @@ def sample_smc(
108107
Defaults to ``True``.
109108
idata_kwargs : dict, optional
110109
Keyword arguments for :func:`pymc3.to_inference_data`
110+
progressbar : bool, optional default=True
111+
Whether or not to display a progress bar in the command line.
112+
111113
Notes
112114
-----
113115
SMC works by moving through successive stages. At each stage the inverse temperature
@@ -153,6 +155,16 @@ def sample_smc(
153155
816-832. `link <http://ascelibrary.org/doi/abs/10.1061/%28ASCE%290733-9399
154156
%282007%29133:7%28816%29>`__
155157
"""
158+
159+
if parallel is not None:
160+
warnings.warn(
161+
"The argument parallel is deprecated, use the argument cores instead.",
162+
DeprecationWarning,
163+
stacklevel=2,
164+
)
165+
if parallel is False:
166+
cores = 1
167+
156168
_log = logging.getLogger("pymc3")
157169
_log.info("Initializing SMC sampler...")
158170

@@ -206,19 +218,26 @@ def sample_smc(
206218
)
207219

208220
t1 = time.time()
209-
if parallel and chains > 1:
210-
loggers = [_log] + [None] * (chains - 1)
221+
if cores > 1:
222+
pbar = progress_bar((), total=100, display=progressbar)
223+
pbar.update(0)
224+
pbars = [pbar] + [None] * (chains - 1)
225+
211226
pool = mp.Pool(cores)
212227
results = pool.starmap(
213-
sample_smc_int, [(*params, random_seed[i], i, loggers[i]) for i in range(chains)]
228+
sample_smc_int, [(*params, random_seed[i], i, pbars[i]) for i in range(chains)]
214229
)
215-
216230
pool.close()
217231
pool.join()
232+
218233
else:
219234
results = []
235+
pbar = progress_bar((), total=100 * chains, display=progressbar)
236+
pbar.update(0)
220237
for i in range(chains):
221-
results.append(sample_smc_int(*params, random_seed[i], i, _log))
238+
pbar.offset = 100 * i
239+
pbar.base_comment = f"Chain: {i+1}/{chains}"
240+
results.append(sample_smc_int(*params, random_seed[i], i, pbar))
222241

223242
(
224243
traces,
@@ -310,7 +329,7 @@ def sample_smc_int(
310329
model,
311330
random_seed,
312331
chain,
313-
_log,
332+
progressbar=None,
314333
):
315334
"""Run one SMC instance."""
316335
smc = SMC(
@@ -331,14 +350,22 @@ def sample_smc_int(
331350
betas = []
332351
accept_ratios = []
333352
nsteps = []
353+
354+
if progressbar:
355+
progressbar.comment = f"{getattr(progressbar, 'base_comment', '')} Stage: 0 Beta: 0"
356+
progressbar.update_bar(getattr(progressbar, "offset", 0) + 0)
357+
334358
smc.initialize_population()
335359
smc.setup_kernel()
336360
smc.initialize_logp()
337361

338362
while smc.beta < 1:
339363
smc.update_weights_beta()
340-
if _log is not None:
341-
_log.info(f"Stage: {stage:3d} Beta: {smc.beta:.3f}")
364+
if progressbar:
365+
progressbar.comment = (
366+
f"{getattr(progressbar, 'base_comment', '')} Stage: {stage} Beta: {smc.beta:.3f}"
367+
)
368+
progressbar.update_bar(getattr(progressbar, "offset", 0) + int(smc.beta * 100))
342369
smc.update_proposal()
343370
smc.resample()
344371
smc.mutate()

pymc3/tests/test_smc.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import time
16+
1517
import aesara
1618
import aesara.tensor as at
1719
import numpy as np
@@ -60,9 +62,22 @@ def two_gaussians(x):
6062

6163
self.muref = mu1
6264

65+
with pm.Model() as self.fast_model:
66+
x = pm.Normal("x", 0, 1)
67+
y = pm.Normal("y", x, 1, observed=0)
68+
69+
with pm.Model() as self.slow_model:
70+
x = pm.Normal("x", 0, 1)
71+
y = pm.Normal("y", x, 1, observed=100)
72+
6373
def test_sample(self):
6474
with self.SMC_test:
65-
mtrace = pm.sample_smc(draws=self.samples, return_inferencedata=False)
75+
76+
mtrace = pm.sample_smc(
77+
draws=self.samples,
78+
cores=1, # Fails in parallel due to #4799
79+
return_inferencedata=False,
80+
)
6681

6782
x = mtrace["X"]
6883
mu1d = np.abs(x).mean(axis=0)
@@ -107,39 +122,65 @@ def test_slowdown_warning(self):
107122
with pm.Model() as model:
108123
a = pm.Poisson("a", 5)
109124
y = pm.Normal("y", a, 5, observed=[1, 2, 3, 4])
110-
trace = pm.sample_smc(draws=100, chains=2)
125+
trace = pm.sample_smc(draws=100, chains=2, cores=1)
111126

112127
@pytest.mark.parametrize("chains", (1, 2))
113128
def test_return_datatype(self, chains):
114129
draws = 10
115130

116-
with pm.Model() as m:
117-
x = pm.Normal("x", 0, 1)
118-
y = pm.Normal("y", x, 1, observed=5)
119-
131+
with self.fast_model:
120132
idata = pm.sample_smc(chains=chains, draws=draws)
121133
mt = pm.sample_smc(chains=chains, draws=draws, return_inferencedata=False)
122134

123135
assert isinstance(idata, InferenceData)
124136
assert "sample_stats" in idata
125-
assert len(idata.posterior.chain) == chains
126-
assert len(idata.posterior.draw) == draws
137+
assert idata.posterior.dims["chain"] == chains
138+
assert idata.posterior.dims["draw"] == draws
127139

128140
assert isinstance(mt, MultiTrace)
129141
assert mt.nchains == chains
130142
assert mt["x"].size == chains * draws
131143

132144
def test_convergence_checks(self):
133-
with pm.Model() as m:
134-
x = pm.Normal("x", 0, 1)
135-
y = pm.Normal("y", x, 1, observed=5)
136-
145+
with self.fast_model:
137146
with pytest.warns(
138147
UserWarning,
139148
match="The number of samples is too small",
140149
):
141150
pm.sample_smc(draws=99)
142151

152+
def test_parallel_sampling(self):
153+
# Cache graph
154+
with self.slow_model:
155+
_ = pm.sample_smc(draws=10, chains=1, cores=1, return_inferencedata=False)
156+
157+
chains = 4
158+
draws = 100
159+
160+
t0 = time.time()
161+
with self.slow_model:
162+
idata = pm.sample_smc(draws=draws, chains=chains, cores=4)
163+
t_mp = time.time() - t0
164+
assert idata.posterior.dims["chain"] == chains
165+
assert idata.posterior.dims["draw"] == draws
166+
167+
t0 = time.time()
168+
with self.slow_model:
169+
idata = pm.sample_smc(draws=draws, chains=chains, cores=1)
170+
t_seq = time.time() - t0
171+
assert idata.posterior.dims["chain"] == chains
172+
assert idata.posterior.dims["draw"] == draws
173+
174+
assert t_mp < t_seq
175+
176+
def test_depracated_parallel_arg(self):
177+
with self.fast_model:
178+
with pytest.warns(
179+
DeprecationWarning,
180+
match="The argument parallel is deprecated",
181+
):
182+
pm.sample_smc(draws=10, chains=1, parallel=False)
183+
143184

144185
@pytest.mark.xfail(reason="SMC-ABC not refactored yet")
145186
class TestSMCABC(SeededTest):

0 commit comments

Comments
 (0)