|
17 | 17 | from .backends.ndarray import NDArray
|
18 | 18 | from .distributions.distribution import draw_values
|
19 | 19 | from .model import modelcontext, Point, all_continuous, Model
|
20 |
| -from .step_methods import (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis, |
| 20 | +from .step_methods import (NUTS, HamiltonianMC, Metropolis, DEMetropolis, BinaryMetropolis, |
21 | 21 | BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
|
22 | 22 | Slice, CompoundStep, arraystep)
|
23 | 23 | from .util import update_start_vals, get_untransformed_name, is_transformed_name, get_default_varnames
|
@@ -420,7 +420,16 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None, trace=N
|
420 | 420 | raise
|
421 | 421 | if not parallel:
|
422 | 422 | if has_population_samplers:
|
| 423 | + has_demcmc = np.any([ |
| 424 | + isinstance(m, DEMetropolis) |
| 425 | + for m in (step.methods if isinstance(step, CompoundStep) else [step]) |
| 426 | + ]) |
423 | 427 | _log.info('Population sampling ({} chains)'.format(chains))
|
| 428 | + if has_demcmc and chains <= model.ndim: |
| 429 | + warnings.warn( |
| 430 | + 'DEMetropolis should be used with more chains than dimensions! ' |
| 431 | + '(The model has {} dimensions.)'.format(model.ndim), UserWarning |
| 432 | + ) |
424 | 433 | _print_step_hierarchy(step)
|
425 | 434 | trace = _sample_population(**sample_args, parallelize=cores > 1)
|
426 | 435 | else:
|
|
0 commit comments