Skip to content

Refactor SMC and properly compute marginal likelihood #3124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 18, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 61 additions & 25 deletions pymc3/step_methods/smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .arraystep import metrop_select
from .metropolis import MultivariateNormalProposal
from ..theanof import floatX
from ..theanof import floatX, inputvars, make_shared_replacements, join_nonshared_inputs
from ..model import modelcontext
from ..backends.ndarray import NDArray
from ..backends.base import MultiTrace
Expand Down Expand Up @@ -100,18 +100,19 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
discrete = np.concatenate([[v.dtype in pm.discrete_types] * (v.dsize or 1) for v in variables])
any_discrete = discrete.any()
all_discrete = discrete.all()
prior_logp = theano.function(model.vars, model.varlogpt)
likelihood_logp = theano.function(model.vars, model.datalogpt)
shared = make_shared_replacements(variables, model)
prior_logp = logp_forw([model.varlogpt], variables, shared)
likelihood_logp = logp_forw([model.datalogpt], variables, shared)

pm._log.info('Sample initial stage: ...')
posterior = _initial_population(draws, model, variables)
posterior, var_info = _initial_population(draws, model, variables)

while beta < 1:
# compute plausibility weights (measure fitness)
likelihoods = np.array([likelihood_logp(*sample) for sample in posterior])
likelihoods = np.array([likelihood_logp(sample) for sample in posterior]).squeeze()
beta, old_beta, weights, sj = _calc_beta(beta, likelihoods, step.threshold)
model.marginal_likelihood *= sj
pm._log.info('Beta: {:f} Stage: {:d}'.format(beta, stage))

# resample based on plausibility weights (selection)
resampling_indexes = np.random.choice(np.arange(draws), size=draws, p=weights)
posterior = posterior[resampling_indexes]
Expand All @@ -132,7 +133,7 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
# Apply Metropolis kernel (mutation)
proposed = 0.
accepted = 0.
priors = np.array([prior_logp(*sample) for sample in posterior])
priors = np.array([prior_logp(sample) for sample in posterior]).squeeze()
tempered_post = priors + likelihoods * beta
for draw in tqdm(range(draws), disable=not progressbar):
old_tempered_post = tempered_post[draw]
Expand All @@ -152,7 +153,7 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
else:
q_new = floatX(q_old + delta)

new_tempered_post = prior_logp(*q_new) + likelihood_logp(*q_new) * beta
new_tempered_post = prior_logp(q_new) + likelihood_logp(q_new)[0] * beta

q_old, accept = metrop_select(new_tempered_post - old_tempered_post, q_new, q_old)
if accept:
Expand All @@ -164,26 +165,32 @@ def sample_smc(draws=5000, step=None, progressbar=False, model=None, random_seed
acc_rate = accepted / proposed
stage += 1

trace = _posterior_to_trace(posterior, model)
trace = _posterior_to_trace(posterior, model, var_info)

return trace

# FIXME!!!!
def _initial_population(samples, model, variables):

def _initial_population(chains, model, variables):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason not to use sample_prior_predictive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, I will change it.

"""
Create an initial population from the prior
"""
population = np.zeros((samples, len(variables)))
population = []
init_rnd = {}
start = model.test_point
for idx, v in enumerate(variables):
var_info = {}
for v in variables:
if pm.util.is_transformed_name(v.name):
trans = v.distribution.transform_used.forward_val
population[:,idx] = trans(v.distribution.dist.random(size=samples, point=start))
init_rnd[v.name] = trans(v.distribution.dist.random(size=chains, point=start))
else:
population[:,idx] = v.random(size=samples, point=start)
init_rnd[v.name] = v.random(size=chains, point=start)
var_info[v.name] = (start[v.name].shape, start[v.name].size)

for i in range(chains):
point = pm.Point({v.name: init_rnd[v.name][i] for v in variables}, model=model)
population.append(model.dict_to_array(point))

return population
return np.array(population), var_info


def _calc_beta(beta, likelihoods, threshold=0.5):
Expand All @@ -204,12 +211,14 @@ def _calc_beta(beta, likelihoods, threshold=0.5):

Returns
-------
beta : float
new_beta : float
tempering parameter of the next stage
beta : float
old_beta : float
tempering parameter of the current stage
weights : numpy array
Importance weights (floats)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs an update (beta also appears twice above).

sj : float
Partial marginal likelihood
"""
low_beta = old_beta = beta
up_beta = 2.
Expand All @@ -228,10 +237,10 @@ def _calc_beta(beta, likelihoods, threshold=0.5):
low_beta = new_beta
if new_beta >= 1:
new_beta = 1
lala = np.exp((new_beta - old_beta) * likelihoods)
sj = np.exp((new_beta - old_beta) * likelihoods)
weights_un = np.exp((new_beta - old_beta) * (likelihoods - likelihoods.max()))
weights = weights_un / np.sum(weights_un)
return new_beta, old_beta, weights, np.mean(lala)
return new_beta, old_beta, weights, np.mean(sj)


def _calc_covariance(posterior_array, weights):
Expand All @@ -243,6 +252,7 @@ def _calc_covariance(posterior_array, weights):
raise ValueError('Sample covariances not valid! Likely "chains" is too small!')
return np.atleast_2d(cov)


def _tune(acc_rate):
"""
Tune adaptively based on the acceptance rate.
Expand All @@ -261,15 +271,41 @@ def _tune(acc_rate):
b = 8. / 9
return (a + b * acc_rate) ** 2

def _posterior_to_trace(posterior, model):

def _posterior_to_trace(posterior, model, var_info):
"""
Save results into a PyMC3 trace
"""
length_pos = len(posterior)
lenght_pos = len(posterior)
varnames = [v.name for v in model.vars]

with model:
strace = NDArray(model)
strace.setup(length_pos, 0)
for i in range(length_pos):
strace.record({k:v for k, v in zip(varnames, posterior[i])})
strace.setup(lenght_pos, 0)
for i in range(lenght_pos):
value = []
size = 0
for var in varnames:
shape, new_size = var_info[var]
value.append(posterior[i][size:size+new_size].reshape(shape))
size += new_size
strace.record({k: v for k, v in zip(varnames, value)})
return MultiTrace([strace])


def logp_forw(out_vars, vars, shared):
    """Compile a Theano function mapping one input array to ``out_vars``.

    The input and output variables are passed through
    ``join_nonshared_inputs``, which yields the output expressions together
    with the single input variable the compiled function is called with.

    Parameters
    ----------
    out_vars : List
        containing :class:`pymc3.Distribution` for the output variables
    vars : List
        containing :class:`pymc3.Distribution` for the input variables
    shared : List
        containing :class:`theano.tensor.Tensor` for depended shared data

    Returns
    -------
    compiled : theano function
        Callable taking the single joined input and returning the evaluated
        outputs; its ``trust_input`` attribute is set to ``True``.
    """
    outputs, joined_input = join_nonshared_inputs(out_vars, vars, shared)
    compiled = theano.function([joined_input], outputs)
    # NOTE(review): with trust_input enabled Theano presumably skips its
    # per-call input validation, so callers must pass correctly typed
    # arrays -- confirm against the Theano documentation.
    compiled.trust_input = True
    return compiled
56 changes: 19 additions & 37 deletions pymc3/tests/test_smc.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
import pymc3 as pm
import numpy as np
from pymc3.backends.smc_text import TextStage
import pytest
from tempfile import mkdtemp
import shutil
import theano.tensor as tt
import theano

from .helpers import SeededTest


@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
class TestSMC(SeededTest):

def setup_class(self):
super(TestSMC, self).setup_class()
self.test_folder = mkdtemp(prefix='ATMIP_TEST')

self.samples = 1500
self.chains = 200
self.samples = 1000
n = 4
mu1 = np.ones(n) * (1. / 2)
mu2 = - mu1
Expand All @@ -40,44 +31,35 @@ def two_gaussians(x):
- 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
return tt.log(w1 * tt.exp(log_like1) + w2 * tt.exp(log_like2))

with pm.Model() as self.ATMIP_test:
with pm.Model() as self.SMC_test:
X = pm.Uniform('X', lower=-2, upper=2., shape=n)
llk = pm.Potential('muh', two_gaussians(X))

self.muref = mu1


@pytest.mark.parametrize(['cores', 'stage'], [[1, 0], [2, 6]])
def test_sample_n_core(self, cores, stage):
step_kwargs = {'homepath': self.test_folder, 'stage': stage}
with self.ATMIP_test:
def test_sample(self):
with self.SMC_test:
mtrace = pm.sample(draws=self.samples,
chains=self.chains,
cores=cores,
step = pm.SMC(),
step_kwargs=step_kwargs)
step = pm.SMC())

x = mtrace.get_values('X')
x = mtrace['X']
mu1d = np.abs(x).mean(axis=0)
np.testing.assert_allclose(self.muref, mu1d, rtol=0., atol=0.03)
# Scenario IV Ching, J. & Chen, Y. 2007
#assert np.round(np.log(self.ATMIP_test.marginal_likelihood)) == -12.0

def test_stage_handler(self):
stage_number = -1
stage_handler = TextStage(self.test_folder)

step = stage_handler.load_atmip_params(stage_number, model=self.ATMIP_test)
assert step.stage == stage_number
def test_ml(self):
data = np.repeat([1, 0], [50, 50])
marginals = []
a_prior_0, b_prior_0 = 1., 1.
a_prior_1, b_prior_1 = 20., 20.

corrupted_chains = stage_handler.recover_existing_results(stage_number,
self.samples / self.chains,
self.chains,
step,
model=self.ATMIP_test)
assert len(corrupted_chains) == self.chains
for alpha, beta in ((a_prior_0, b_prior_0), (a_prior_1, b_prior_1)):
with pm.Model() as model:
a = pm.Beta('a', alpha, beta)
y = pm.Bernoulli('y', a, observed=data)
trace = pm.sample(2000, step=pm.SMC())
marginals.append(model.marginal_likelihood)
# compare to the analytical result
assert np.floor(marginals[1] / marginals[0]) == 4.0

rtrace = stage_handler.load_result_trace(model=self.ATMIP_test)

def teardown_class(self):
shutil.rmtree(self.test_folder)
80 changes: 42 additions & 38 deletions pymc3/tests/test_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,40 +124,46 @@ class TestStepMethods(object): # yield test doesn't work subclassing object
-2.24238542e+00, -1.01648100e+00, -1.01648100e+00, -7.60912865e-01,
1.44384812e+00, 2.07355127e+00, 1.91390340e+00, 1.66559696e+00]),
smc.SMC: np.array(
[ 0.61562138, -0.56082978, -0.89760381, 1.47368457, 0.33300527, 0.85567605,
-1.33503519, -1.47996682, -0.3725601, 0.75713321, 1.81055917, 0.39193534,
0.10083821, 0.55569412, -0.65879812, -0.61545061, -2.65522875, 0.93801687,
2.40499211, -0.63022535, 0.09565784, -1.00650846, 1.65901231, 0.18429996,
1.64642521, 0.5589963, -0.40452525, -0.9402324, 0.53813986, 0.55785946,
1.22966132, 0.2782562, -0.81254158, -0.08076293, -0.29136329, 0.62914226,
0.16049388, -0.06386387, 1.8103961, -0.98444811, -0.36333739, 0.88703339,
-0.08482673, -0.23224262, -0.11348807, 1.09401682, -0.58594449, -0.12728503,
-0.82408778, -1.82770764, -2.28859404, -0.51814943, -1.53652851, 0.66313366,
1.61666698, 1.41284768, -0.05129251, 0.96166282, 1.00446145, -0.86380886,
-1.13410885, -0.48311125, -1.25446622, -0.48452779, -0.84647195, -0.43201729,
-1.22186151, 1.18698485, 0.33434434, -0.40650775, 0.47740064, 0.96943022,
1.15534028, -0.86220564, -0.26049285, -1.17489183, 0.66796904, -1.68920203,
-0.96308602, -1.73542483, -0.84744376, 0.91864514, -0.02724505, 0.16143404,
0.65747707, -1.49655923, -0.32010575, 1.20255401, 0.1203948, -1.30017822,
1.55895643, -0.74799042, -1.5938872, 0.69297144, -1.32932843, -0.16886992,
-1.01437109, 0.32476589, 1.02509164, 0.31274278, -0.7908909, 1.18439217,
-0.96132492, -0.4934065, 0.71438293, 0.09829997, 1.81936381, 0.47941016,
0.3717936, 0.14339747, 1.24288736, 0.92520773, 0.69025067, 0.96618094,
0.69085402, -1.12128175, 0.11228076, 0.7711306, 0.12859226, 0.65792466,
-0.07422313, 1.74736739, 0.24120104, 0.74946338, 0.66260928, -0.34070258,
1.09875434, -0.4159233, -0.01607339, 1.20921296, -0.29176047, 0.47367867,
-1.45788116, -0.40198772, 0.44502909, 0.65623719, 0.99422221, 1.37241668,
-0.05163759, 0.82729935, 0.59458429, 1.10870872, -1.00730291, -0.07837131,
-0.28144688, -0.03052015, 1.05263496, 0.19011829, -0.98807301, -0.77388355,
-1.68729554, 0.03018351, 0.39424573, 0.98343413, -1.40600196, 1.19764243,
1.64712279, 0.68929684, -0.54301669, -0.29369924, 0.09052877, 2.64067523,
-1.25887138, 1.65991714, 0.71271397, -0.50396329, 1.2182173, 0.2472108,
-0.2990774, 0.1646579, 0.21418971, -0.0876372, 0.66714317, -0.43490764,
-2.17899663, -0.2681325, -3.10431098, -1.38211864, 0.02041712, 0.16319981,
-1.02526047, 1.93088335, -0.36975507, -0.61332039, 0.33666881, -0.23766903,
-0.58478679, 1.38941035, -0.45829187, -1.12505096, -1.4814355, 0.61790977,
0.58867984, 1.38693864, 1.80845772, -1.63246225, -1.48247172, -0.69197631,
0.65045375, -0.09601979]),
np.array([1.41117913, 1.64786848, 0.91722731, 1.45389228, 1.02451573,
0.85798363, 1.09617213, 1.91933133, 1.38944922, 1.28784728,
1.69916542, 1.2740302 , 1.98886485, 1.69370475, 1.61759217,
1.26563918, 0.58791742, 1.66085807, 1.31776859, 1.57789075,
2.30023319, 0.82982445, 1.2177862 , 0.99787145, 1.05348682,
1.15775351, 0.54392086, 2.11077821, 2.47461004, 1.73303454,
1.01737162, 1.83675088, 1.51428954, 0.52282184, 2.37788098,
1.04627241, 2.02728668, 2.07908118, 1.863258 , 1.50222989,
1.67973843, 2.16871644, 1.54349527, 1.33198955, 1.8797815 ,
1.43952095, 1.56282684, 2.28026637, 1.24773664, 1.28251139,
2.65020285, 1.62000451, 1.0821298 , 2.20458889, 1.56753094,
2.04763651, 1.28639926, 1.68799005, 1.4186754 , 0.7952981 ,
0.9703601 , 1.236214 , 1.63472278, 1.98110675, 1.26566976,
2.1382887 , 1.11910639, 1.00691799, 1.14989 , 0.98631041,
1.59265667, 1.42570867, 1.50940687, 0.80387284, 1.69967565,
1.66695801, 1.25114565, 0.907412 , 1.66828172, 1.85457132,
1.20480774, 2.2224195 , 1.3798713 , 2.77890671, 1.67045077,
1.84703256, 0.9840681 , 1.50819703, 1.79269974, 1.34110018,
1.54102471, 0.53134127, 1.31140629, 1.07681732, 2.05505094,
2.27099409, 1.64648775, 1.59317554, 0.74550987, 1.23768242,
0.57339393, 1.72719298, 1.34877586, 1.40648412, 1.36633963,
2.13211623, 0.3422404 , 1.80987189, 1.17095936, 2.22665412,
2.14976788, 1.66844646, 1.40758582, 1.31435313, 1.45675102,
1.27374917, 1.78272082, 1.70903882, 1.64561402, 1.54613473,
1.58155217, 0.41347197, 0.2454103 , 1.18388892, 1.43178759,
1.29884578, 0.98116748, 1.43191455, 0.72443333, 1.55066915,
1.15537114, 1.51085638, 1.79389202, 1.33979587, 0.73739759,
2.16365184, 1.00512435, 0.54543314, 1.93690273, 0.49008933,
1.82693925, 1.44985792, 2.04549251, 1.38125344, 1.1832728 ,
1.02750196, 1.50665663, 1.62350414, 2.16769179, 1.20632128,
1.24008984, 1.65649161, 1.18786854, 1.89970735, 1.27612317,
0.92133366, 2.20694733, 1.24515401, 1.8066277 , 0.73046907,
1.65105233, 0.77993076, 1.23035535, 0.91091633, 2.42155306,
1.69911773, 1.47576401, 1.28133376, 2.15270933, 1.49045131,
0.97506262, 2.03833814, 0.99020804, 1.57680544, 1.71924419,
0.68686093, 1.51160641, 1.44976239, 1.18024189, 1.23250215,
2.36732805, 1.26811905, 1.75390284, 1.62451338, 0.71297471,
0.76569102, 1.79209952, 2.24259697, 1.34634267, 0.86080879,
1.59529186, 2.05688065, 1.27474796, 1.20044171, 1.98148641,
0.49048629, 1.89423563, 1.01198881, 1.00667743, 1.03871664])),
}

def setup_class(self):
Expand Down Expand Up @@ -190,14 +196,12 @@ def check_trace(self, step_method):
n_steps = 100
with Model() as model:
x = Normal('x', mu=0, sd=1)
y = Normal('y', mu=x, sd=1, observed=[1, 2, 3])
if step_method.__name__ == 'SMC':
trace = sample(draws=200,
chains=2,
start=[{'x':1.}, {'x':-1.}],
random_seed=1,
progressbar=False,
step=step_method(),
step_kwargs={'homepath': self.temp_dir})
step=step_method())
elif step_method.__name__ == 'NUTS':
step = step_method(scaling=model.test_point)
trace = sample(0, tune=n_steps,
Expand Down