Skip to content

Commit 86bb49a

Browse files
michaelosthegejunpenglao
authored andcommitted
Updated DEMetropolis warnings (#3721)
* remove DEMetropolis warning closes #3718 * test for warning about too few chains (#3719) + accelerate existing test by running fewer iterations non-parallelized * warn when DE-MCMC is used with too few chains closes #3719 * fail faster on too small populations + keep old check for added safety
1 parent 1c30a6f commit 86bb49a

File tree

4 files changed

+35
-8
lines changed

4 files changed

+35
-8
lines changed

pymc3/sampling.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
BinaryMetropolis,
2626
BinaryGibbsMetropolis,
2727
CategoricalGibbsMetropolis,
28+
DEMetropolis,
2829
Slice,
2930
CompoundStep,
3031
arraystep,
@@ -480,7 +481,21 @@ def sample(
480481
raise
481482
if not parallel:
482483
if has_population_samplers:
483-
_log.info("Population sampling ({} chains)".format(chains))
484+
has_demcmc = np.any([
485+
isinstance(m, DEMetropolis)
486+
for m in (step.methods if isinstance(step, CompoundStep) else [step])
487+
])
488+
_log.info('Population sampling ({} chains)'.format(chains))
489+
if has_demcmc and chains < 3:
490+
raise ValueError(
491+
'DEMetropolis requires at least 3 chains. ' \
492+
'For this {}-dimensional model you should use ≥{} chains'.format(model.ndim, model.ndim + 1)
493+
)
494+
if has_demcmc and chains <= model.ndim:
495+
warnings.warn(
496+
'DEMetropolis should be used with more chains than dimensions! '
497+
'(The model has {} dimensions.)'.format(model.ndim), UserWarning
498+
)
484499
_print_step_hierarchy(step)
485500
trace = _sample_population(**sample_args, parallelize=cores > 1)
486501
else:

pymc3/step_methods/arraystep.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,10 @@ def link_population(self, population, chain_index):
212212
self.this_chain = chain_index
213213
self.other_chains = [c for c in range(len(population)) if c != chain_index]
214214
if not len(self.other_chains) > 1:
215-
raise ValueError('Population is just {} + {}. This is too small. You should ' \
216-
'increase the number of chains.'.format(self.this_chain, self.other_chains))
215+
raise ValueError(
216+
'Population is just {} + {}. ' \
217+
'This is too small and the error should have been raised earlier.'.format(self.this_chain, self.other_chains)
218+
)
217219
return
218220

219221

pymc3/step_methods/metropolis.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy.random as nr
33
import theano
44
import scipy.linalg
5-
import warnings
65

76
from ..distributions import draw_values
87
from .arraystep import ArrayStepShared, PopulationArrayStepShared, ArrayStep, metrop_select, Competence
@@ -537,8 +536,6 @@ class DEMetropolis(PopulationArrayStepShared):
537536

538537
def __init__(self, vars=None, S=None, proposal_dist=None, lamb=None, scaling=0.001,
539538
tune=True, tune_interval=100, model=None, mode=None, **kwargs):
540-
warnings.warn('Population based sampling methods such as DEMetropolis are experimental.' \
541-
' Use carefully and be extra critical about their results!')
542539

543540
model = pm.modelcontext(model)
544541

pymc3/tests/test_step.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,21 @@ def test_checks_population_size(self):
702702
for stepper in TestPopulationSamplers.steppers:
703703
step = stepper()
704704
with pytest.raises(ValueError):
705-
trace = sample(draws=100, chains=1, step=step)
706-
trace = sample(draws=100, chains=4, step=step)
705+
sample(draws=10, tune=10, chains=1, cores=1, step=step)
706+
# don't parallelize to make test faster
707+
sample(draws=10, tune=10, chains=4, cores=1, step=step)
708+
pass
709+
710+
def test_demcmc_warning_on_small_populations(self):
711+
"""Test that a warning is raised when n_chains <= n_dims"""
712+
with Model() as model:
713+
Normal("n", mu=0, sigma=1, shape=(2,3))
714+
with pytest.warns(UserWarning) as record:
715+
sample(
716+
draws=5, tune=5, chains=6, step=DEMetropolis(),
717+
# make tests faster by not parallelizing; disable convergence warning
718+
cores=1, compute_convergence_checks=False
719+
)
707720
pass
708721

709722
def test_nonparallelized_chains_are_random(self):

0 commit comments

Comments
 (0)