Skip to content

Commit 53fa46b

Browse files
author
Junpeng Lao
authored
Merge pull request #2339 from hvasbath/smc_model_llk
Remove necessity to define likelihood variable inside model, is done …
2 parents b5f3271 + cac698c commit 53fa46b

File tree

5 files changed

+61
-91
lines changed

5 files changed

+61
-91
lines changed

docs/source/notebooks/SMC2_gaussians.ipynb

Lines changed: 29 additions & 33 deletions
Large diffs are not rendered by default.

pymc3/step_methods/metropolis.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,12 @@ def __init__(self, s):
5050
self.chol = scipy.linalg.cholesky(s, lower=True)
5151

5252
def __call__(self, num_draws=None):
53-
b = np.random.randn(self.n)
54-
return np.dot(self.chol, b)
53+
if num_draws is not None:
54+
b = np.random.randn(self.n, num_draws)
55+
return np.dot(self.chol, b).T
56+
else:
57+
b = np.random.randn(self.n)
58+
return np.dot(self.chol, b)
5559

5660

5761
class Metropolis(ArrayStepShared):

pymc3/step_methods/smc.py

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,33 +25,15 @@
2525
from ..theanof import inputvars, make_shared_replacements, join_nonshared_inputs
2626
import numpy.random as nr
2727

28+
from .metropolis import MultivariateNormalProposal
2829
from .arraystep import metrop_select
2930
from ..backends import smc_text as atext
3031

31-
__all__ = ['SMC', 'ATMIP_sample']
32+
__all__ = ['SMC', 'sample_smc']
3233

3334
EXPERIMENTAL_WARNING = "Warning: SMC is an experimental step method, and not yet"\
3435
" recommended for use in PyMC3!"
3536

36-
37-
class Proposal(object):
38-
"""Proposal distributions modified from pymc3 to initially create all the
39-
Proposal steps without repeated execution of the RNG - significant speedup!
40-
41-
Parameters
42-
----------
43-
s : :class:`numpy.ndarray`
44-
"""
45-
def __init__(self, s):
46-
self.s = np.atleast_1d(s)
47-
48-
49-
class MultivariateNormalProposal(Proposal):
50-
def __call__(self, num_draws=None):
51-
return np.random.multivariate_normal(
52-
mean=np.zeros(self.s.shape[0]), cov=self.s, size=num_draws)
53-
54-
5537
proposal_dists = {
5638
'MultivariateNormal': MultivariateNormalProposal,
5739
}
@@ -147,6 +129,13 @@ def __init__(self, vars=None, out_vars=None, n_chains=100, scaling=1., covarianc
147129
vars = inputvars(vars)
148130

149131
if out_vars is None:
132+
if not any(likelihood_name == RV.name for RV in model.unobserved_RVs):
133+
with model:
134+
llk = pm.Deterministic(likelihood_name, model.logpt)
135+
else:
136+
raise ValueError(
137+
'The model likelihood name is already being used by a RV!')
138+
150139
out_vars = model.unobserved_RVs
151140

152141
out_varnames = [out_var.name for out_var in out_vars]
@@ -419,9 +408,9 @@ def resample(self):
419408
return outindx
420409

421410

422-
def ATMIP_sample(n_steps, step=None, start=None, homepath=None, chain=0, stage=0, n_jobs=1,
411+
def sample_smc(n_steps, step=None, start=None, homepath=None, chain=0, stage=0, n_jobs=1,
423412
tune=None, progressbar=False, model=None, random_seed=-1, rm_flag=False):
424-
"""(C)ATMIP sampling algorithm (Cascading - (C) not always relevant)
413+
"""Sequential Monte Carlo sampling
425414
426415
Samples the solution space with n_chains of Metropolis chains, where each
427416
chain has n_steps iterations. Once finished, the sampled traces are
@@ -524,25 +513,8 @@ def ATMIP_sample(n_steps, step=None, start=None, homepath=None, chain=0, stage=0
524513
draws = step.n_steps
525514

526515
stage_handler.clean_directory(stage, None, rm_flag)
527-
with model:
528-
chains = stage_handler.recover_existing_results(stage, draws, step, n_jobs)
529-
if chains is not None:
530-
rest = len(chains) % n_jobs
531-
if rest > 0:
532-
pm._log.info('Fixing %i chains ...' % rest)
533-
chains, rest_chains = chains[:-rest], chains[-rest:]
534-
# process traces that are not a multiple of n_jobs
535-
sample_args = {
536-
'draws': draws,
537-
'step': step,
538-
'stage_path': stage_handler.stage_path(stage),
539-
'progressbar': progressbar,
540-
'model': model,
541-
'n_jobs': rest,
542-
'chains': rest_chains}
543516

544-
_iter_parallel_chains(**sample_args)
545-
pm._log.info('Back to normal!')
517+
chains = stage_handler.recover_existing_results(stage, draws, step, n_jobs)
546518

547519
with model:
548520
while step.beta < 1:
@@ -556,7 +528,7 @@ def ATMIP_sample(n_steps, step=None, start=None, homepath=None, chain=0, stage=0
556528
pm._log.info('Beta: %f Stage: %i' % (step.beta, step.stage))
557529

558530
# Metropolis sampling intermediate stages
559-
chains = stage_handler.clean_directory(stage, chains, rm_flag)
531+
chains = stage_handler.clean_directory(step.stage, chains, rm_flag)
560532
sample_args = {
561533
'draws': draws,
562534
'step': step,

pymc3/tests/test_smc.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,26 +51,25 @@ def two_gaussians(x):
5151
upper=2. * np.ones_like(mu1),
5252
testval=-1. * np.ones_like(mu1),
5353
transform=None)
54-
like = pm.Deterministic('like', two_gaussians(X))
55-
llk = pm.Potential('like_potential', like)
54+
llk = pm.Potential('muh', two_gaussians(X))
55+
56+
self.step = smc.SMC(
57+
n_chains=self.n_chains,
58+
tune_interval=self.tune_interval,
59+
model=self.ATMIP_test)
5660

5761
self.muref = mu1
5862

59-
@pytest.mark.parametrize('n_jobs', [1, 2])
60-
def test_sample_n_core(self, n_jobs):
63+
@pytest.mark.parametrize(['n_jobs', 'stage'], [[1, 0], [2, 6]])
64+
def test_sample_n_core(self, n_jobs, stage):
6165

6266
def last_sample(x):
6367
return x[(self.n_steps - 1)::self.n_steps]
6468

65-
step = smc.SMC(
66-
n_chains=self.n_chains,
67-
tune_interval=self.tune_interval,
68-
model=self.ATMIP_test,
69-
likelihood_name=self.ATMIP_test.deterministics[0].name)
70-
71-
mtrace = smc.ATMIP_sample(
69+
mtrace = smc.sample_smc(
7270
n_steps=self.n_steps,
73-
step=step,
71+
step=self.step,
72+
stage=stage,
7473
n_jobs=n_jobs,
7574
progressbar=True,
7675
homepath=self.test_folder,

pymc3/tests/test_step.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .checks import close_to
66
from .models import simple_categorical, mv_simple, mv_simple_discrete, simple_2model, mv_prior_simple
77
from pymc3.sampling import assign_step_methods, sample
8-
from pymc3.model import Model, Deterministic
8+
from pymc3.model import Model
99
from pymc3.step_methods import (NUTS, BinaryGibbsMetropolis, CategoricalGibbsMetropolis,
1010
Metropolis, Slice, CompoundStep, NormalProposal,
1111
MultivariateNormalProposal, HamiltonianMC,
@@ -163,8 +163,7 @@ def check_trace(self, step_method):
163163
with Model():
164164
x = Normal('x', mu=0, sd=1)
165165
if step_method.__name__ == 'SMC':
166-
Deterministic('like', - 0.5 * tt.log(2 * np.pi) - 0.5 * x.T.dot(x))
167-
trace = smc.ATMIP_sample(n_steps=n_steps, step=step_method(random_seed=1),
166+
trace = smc.sample_smc(n_steps=n_steps, step=step_method(random_seed=1),
168167
n_jobs=1, progressbar=False,
169168
homepath=self.temp_dir)
170169
else:

0 commit comments

Comments
 (0)