
Commit 2e3caa1

Merge pull request #2027 from ferrine/sample_vp
Sample vp for Approximations, deprecate old ADVI
2 parents c0773b7 + 8f99a2c commit 2e3caa1

File tree

13 files changed: +552 / -198 lines

docs/source/api/data.rst

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
-*****
+****
 Data
-*****
+****
 
 .. currentmodule:: pymc3.data
 

docs/source/api/inference.rst

Lines changed: 17 additions & 9 deletions

@@ -50,26 +50,34 @@ Hamiltonian Monte Carlo
 Variational
 -----------
 
-ADVI
+OPVI
 ^^^^
 
-.. currentmodule:: pymc3.variational.advi
+.. currentmodule:: pymc3.variational.opvi
 
-.. automodule:: pymc3.variational.advi
+.. automodule:: pymc3.variational.opvi
     :members:
 
-ADVI minibatch
-^^^^^^^^^^^^^^
+Inference
+^^^^^^^^^
 
-.. currentmodule:: pymc3.variational.advi_minibatch
+.. currentmodule:: pymc3.variational.inference
 
-.. automodule:: pymc3.variational.advi_minibatch
+.. automodule:: pymc3.variational.inference
     :members:
 
-ADVI approximations
-^^^^^^^^^^^^^^^^^^^
+Approximations
+^^^^^^^^^^^^^^
 
 .. currentmodule:: pymc3.variational.approximations
 
 .. automodule:: pymc3.variational.approximations
     :members:
+
+Operators
+^^^^^^^^^
+
+.. currentmodule:: pymc3.variational.operators
+
+.. automodule:: pymc3.variational.operators
+    :members:

pymc3/sampling.py

Lines changed: 20 additions & 9 deletions

@@ -22,6 +22,7 @@
 STEP_METHODS = (NUTS, HamiltonianMC, Metropolis, BinaryMetropolis,
                 BinaryGibbsMetropolis, Slice, CategoricalGibbsMetropolis)
 
+
 def assign_step_methods(model, step=None, methods=STEP_METHODS,
                         step_kwargs=None):
     """Assign model variables to appropriate step methods.

@@ -568,19 +569,29 @@ def init_nuts(init='ADVI', njobs=1, n_init=500000, model=None,
     init = init.lower()
 
     if init == 'advi':
-        v_params = pm.variational.advi(n=n_init, random_seed=random_seed,
-                                       progressbar=progressbar)
-        start = pm.variational.sample_vp(v_params, njobs, progressbar=False,
-                                         hide_transformed=False,
-                                         random_seed=random_seed)
+        approx = pm.fit(
+            seed=random_seed,
+            n=n_init, method='advi', model=model,
+            callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-2)],
+            progressbar=progressbar
+        )  # type: pm.MeanField
+        start = approx.sample(draws=njobs)
+        cov = approx.cov.eval()
         if njobs == 1:
             start = start[0]
-        cov = np.power(model.dict_to_array(v_params.stds), 2)
     elif init == 'advi_map':
         start = pm.find_MAP()
-        v_params = pm.variational.advi(n=n_init, start=start,
-                                       random_seed=random_seed)
-        cov = np.power(model.dict_to_array(v_params.stds), 2)
+        approx = pm.MeanField(model=model, start=start)
+        pm.fit(
+            seed=random_seed,
+            n=n_init, method=pm.ADVI.from_mean_field(approx),
+            callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-2)],
+            progressbar=progressbar
+        )
+        start = approx.sample(draws=njobs)
+        cov = approx.cov.eval()
+        if njobs == 1:
+            start = start[0]
     elif init == 'map':
         start = pm.find_MAP()
         cov = pm.find_hessian(point=start)
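
Note: for orientation, a minimal standalone sketch of the new initialization path, assuming a toy model; `pm.fit`, `approx.sample`, and `approx.cov` are used exactly as in the diff above.

    import numpy as np
    import pymc3 as pm

    with pm.Model() as model:
        mu = pm.Normal('mu', mu=0., sd=1.)
        pm.Normal('obs', mu=mu, sd=1., observed=np.random.randn(100))

        # fit a mean-field approximation, stopping early on convergence
        approx = pm.fit(
            n=10000, method='advi', model=model,
            callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-2)],
        )  # type: pm.MeanField
        start = approx.sample(draws=2)   # one start point per chain
        cov = approx.cov.eval()          # diagonal covariance used to tune NUTS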

pymc3/tests/test_variational_inference.py

Lines changed: 36 additions & 29 deletions

@@ -7,7 +7,7 @@
 from pymc3 import Model, Normal
 from pymc3.variational import (
     ADVI, FullRankADVI, SVGD,
-    Histogram,
+    Empirical,
     fit
 )
 from pymc3.variational.operators import KL

@@ -59,7 +59,7 @@ def _test_aevb(self):
         with model:
             inference = self.inference(local_rv={x: (mu, rho)})
             approx = inference.fit(3, obj_n_mc=2, obj_optimizer=self.optimizer)
-            approx.sample_vp(10)
+            approx.sample(10)
             approx.apply_replacements(
                 y,
                 more_replacements={x: np.asarray([1, 1], dtype=x.dtype)}

@@ -105,17 +105,17 @@ def test_vars_view_dynamic_size_numpy(self):
             x_sampled = app.view(app.random_fn(), 'x')
             assert x_sampled.shape == () + model['x'].dshape
 
-    def test_sample_vp(self):
+    def test_sample(self):
         n_samples = 100
         xs = np.random.binomial(n=1, p=0.2, size=n_samples)
         with pm.Model():
             p = pm.Beta('p', alpha=1, beta=1)
             pm.Binomial('xs', n=1, p=p, observed=xs)
             app = self.inference().approx
-        trace = app.sample_vp(draws=1, hide_transformed=True)
+        trace = app.sample(draws=1, hide_transformed=True)
         assert trace.varnames == ['p']
         assert len(trace) == 1
-        trace = app.sample_vp(draws=10, hide_transformed=False)
+        trace = app.sample(draws=10, hide_transformed=False)
         assert sorted(trace.varnames) == ['p', 'p_logodds_']
         assert len(trace) == 10

@@ -145,8 +145,11 @@ def test_optimizer_with_full_data(self):
             Normal('x', mu=mu_, sd=sd, observed=data)
             inf = self.inference()
             inf.fit(10)
-            approx = inf.fit(self.NITER, obj_optimizer=self.optimizer)
-            trace = approx.sample_vp(10000)
+            approx = inf.fit(self.NITER,
+                             obj_optimizer=self.optimizer,
+                             callbacks=
+                             [pm.callbacks.CheckParametersConvergence()])
+            trace = approx.sample(10000)
         np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.1)
         np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.4)

@@ -172,8 +175,10 @@ def create_minibatch(data):
             mu_ = Normal('mu', mu=mu0, sd=sd0, testval=0)
             Normal('x', mu=mu_, sd=sd, observed=minibatches, total_size=n)
             inf = self.inference()
-            approx = inf.fit(self.NITER * 3, obj_optimizer=self.optimizer)
-            trace = approx.sample_vp(10000)
+            approx = inf.fit(self.NITER * 3, obj_optimizer=self.optimizer,
+                             callbacks=
+                             [pm.callbacks.CheckParametersConvergence()])
+            trace = approx.sample(10000)
         np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.1)
         np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.4)

@@ -203,8 +208,10 @@ def cb(*_):
             mu_ = Normal('mu', mu=mu0, sd=sd0, testval=0)
             Normal('x', mu=mu_, sd=sd, observed=data_t, total_size=n)
             inf = self.inference()
-            approx = inf.fit(self.NITER * 3, callbacks=[cb], obj_n_mc=10, obj_optimizer=self.optimizer)
-            trace = approx.sample_vp(10000)
+            approx = inf.fit(self.NITER * 3, callbacks=
+                             [cb, pm.callbacks.CheckParametersConvergence()],
+                             obj_n_mc=10, obj_optimizer=self.optimizer)
+            trace = approx.sample(10000)
         np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.4)
         np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.4)

@@ -274,30 +281,30 @@ class TestSVGD(TestApproximates.Base):
     optimizer = functools.partial(pm.adam, learning_rate=.1)
 
 
-class TestHistogram(SeededTest):
+class TestEmpirical(SeededTest):
     def test_sampling(self):
         with models.multidimensional_model()[1]:
             full_rank = FullRankADVI()
             approx = full_rank.fit(20)
-            trace0 = approx.sample_vp(10000)
-            histogram = Histogram(trace0)
-            trace1 = histogram.sample_vp(100000)
+            trace0 = approx.sample(10000)
+            approx = Empirical(trace0)
+            trace1 = approx.sample(100000)
         np.testing.assert_allclose(trace0['x'].mean(0), trace1['x'].mean(0), atol=0.01)
         np.testing.assert_allclose(trace0['x'].var(0), trace1['x'].var(0), atol=0.01)
 
-    def test_aevb_histogram(self):
+    def test_aevb_empirical(self):
         _, model, _ = models.exponential_beta(n=2)
         x = model.x
         mu = theano.shared(x.init_value)
         rho = theano.shared(np.zeros_like(x.init_value))
         with model:
             inference = ADVI(local_rv={x: (mu, rho)})
             approx = inference.approx
-            trace0 = approx.sample_vp(10000)
-            histogram = Histogram(trace0, local_rv={x: (mu, rho)})
-            trace1 = histogram.sample_vp(10000)
-            histogram.random(no_rand=True)
-            histogram.random_fn(no_rand=True)
+            trace0 = approx.sample(10000)
+            approx = Empirical(trace0, local_rv={x: (mu, rho)})
+            trace1 = approx.sample(10000)
+            approx.random(no_rand=True)
+            approx.random_fn(no_rand=True)
         np.testing.assert_allclose(trace0['y'].mean(0), trace1['y'].mean(0), atol=0.02)
         np.testing.assert_allclose(trace0['y'].var(0), trace1['y'].var(0), atol=0.02)
         np.testing.assert_allclose(trace0['x'].mean(0), trace1['x'].mean(0), atol=0.02)

@@ -310,17 +317,17 @@ def test_random_with_transformed(self):
             p = pm.Uniform('p')
             pm.Bernoulli('trials', p, observed=trials)
             trace = pm.sample(1000, step=pm.Metropolis())
-            histogram = Histogram(trace)
-            histogram.randidx(None).eval()
-            histogram.randidx(1).eval()
-            histogram.random_fn(no_rand=True)
-            histogram.random_fn(no_rand=False)
-            histogram.histogram_logp.eval()
+            approx = Empirical(trace)
+            approx.randidx(None).eval()
+            approx.randidx(1).eval()
+            approx.random_fn(no_rand=True)
+            approx.random_fn(no_rand=False)
+            approx.histogram_logp.eval()
 
     def test_init_from_noize(self):
         with models.multidimensional_model()[1]:
-            histogram = Histogram.from_noise(100)
-            assert histogram.histogram.eval().shape == (100, 6)
+            approx = Empirical.from_noise(100)
+            assert approx.histogram.eval().shape == (100, 6)
 
     _model = models.simple_model()[1]
     with _model:
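
Note: to make the rename concrete, a small sketch (toy model assumed) of the `Histogram` -> `Empirical` and `sample_vp` -> `sample` changes these tests exercise.

    import pymc3 as pm
    from pymc3.variational import FullRankADVI, Empirical

    with pm.Model():
        pm.Normal('x', mu=0., sd=1., shape=2)
        approx = FullRankADVI().fit(200)
        trace0 = approx.sample(1000)   # formerly approx.sample_vp(1000)
        emp = Empirical(trace0)        # formerly Histogram(trace0)
        trace1 = emp.sample(1000)      # resample from the empirical approximation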

pymc3/theanof.py

Lines changed: 14 additions & 3 deletions

@@ -372,17 +372,28 @@ def launch_rng(rng):
 launch_rng(_tt_rng)
 
 
-def tt_rng():
+def tt_rng(seed=None):
     """
-    Get the package-level random number generator.
+    Get the package-level random number generator or new with specified seed.
+
+    Parameters
+    ----------
+    seed : int
+        If not None
+        returns *new* theano random generator without replacing package global one
 
     Returns
     -------
     `theano.sandbox.rng_mrg.MRG_RandomStreams` instance
         `theano.sandbox.rng_mrg.MRG_RandomStreams`
         instance passed to the most recent call of `set_tt_rng`
     """
-    return _tt_rng
+    if seed is None:
+        return _tt_rng
+    else:
+        ret = MRG_RandomStreams(seed)
+        launch_rng(ret)
+        return ret
 
 
 def set_tt_rng(new_rng):
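
Note: a brief sketch of the extended helper, taken directly from the diff above: `tt_rng()` still returns the package-level generator, while `tt_rng(seed)` returns a fresh, independently launched stream without replacing the global one.

    from pymc3.theanof import tt_rng

    global_rng = tt_rng()         # unchanged: the package-level MRG_RandomStreams
    scoped_rng = tt_rng(seed=42)  # new: a separate seeded generator
    assert scoped_rng is not global_rng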

pymc3/variational/__init__.py

Lines changed: 4 additions & 2 deletions

@@ -23,9 +23,10 @@
     fit,
 )
 from .approximations import (
-    Histogram,
+    Empirical,
     FullRank,
-    MeanField
+    MeanField,
+    sample_approx
 )
 
 from . import approximations

@@ -34,3 +35,4 @@
 from . import opvi
 from . import updates
 from . import inference
+from . import callbacks
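
Note: the newly exported `sample_approx` gives a functional counterpart to `Approximation.sample`. A sketch of how it might be called; the exact signature is not shown in this diff, so passing the approximation plus a `draws` count is an assumption.

    import pymc3 as pm

    with pm.Model():
        pm.Normal('x')
        approx = pm.fit(n=1000, method='advi')
        # assumed signature: fitted approximation plus number of draws
        trace = pm.sample_approx(approx, draws=500)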

pymc3/variational/advi.py

Lines changed: 8 additions & 0 deletions

@@ -108,6 +108,10 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
     and Blei, D. M. (2016). Automatic Differentiation Variational
     Inference. arXiv preprint arXiv:1603.00788.
     """
+    import warnings
+    warnings.warn('Old ADVI interface and sample_vp is deprecated and will '
+                  'be removed in future, use pm.fit and pm.sample_approx instead',
+                  DeprecationWarning, stacklevel=2)
     model = pm.modelcontext(model)
     if start is None:
         start = model.test_point

@@ -357,6 +361,10 @@ def sample_vp(
     trace : pymc3.backends.base.MultiTrace
         Samples drawn from the variational posterior.
     """
+    import warnings
+    warnings.warn('Old ADVI interface and sample_vp is deprecated and will '
+                  'be removed in future, use pm.fit and pm.sample_approx instead',
+                  DeprecationWarning, stacklevel=2)
     model = pm.modelcontext(model)
 
     if isinstance(vparams, ADVIFit):
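
Note: the warning text implies this migration; a sketch of the old and new call patterns side by side, assuming a toy model.

    import pymc3 as pm

    with pm.Model():
        pm.Normal('mu', mu=0., sd=1.)

        # old interface, now emits DeprecationWarning:
        #   v_params = pm.variational.advi(n=5000)
        #   trace = pm.variational.sample_vp(v_params, draws=500)

        # replacement named in the warning:
        approx = pm.fit(n=5000, method='advi')
        trace = approx.sample(500)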

pymc3/variational/advi_minibatch.py

Lines changed: 3 additions & 0 deletions

@@ -436,6 +436,9 @@ def advi_minibatch(vars=None, start=None, model=None, n=5000, n_mcsamples=1,
     Weight Uncertainty in Neural Network. In Proceedings of the 32nd
     International Conference on Machine Learning (ICML-15) (pp. 1613-1622).
     """
+    import warnings
+    warnings.warn('Old ADVI interface is deprecated and be removed in future, use pm.ADVI instead',
+                  DeprecationWarning, stacklevel=2)
     if encoder_params is None:
         encoder_params = []
