Skip to content

Commit 0f45720

Browse files
authored
refactor variational module, add histogram approximation (#1904)
* refactor module, add histogram * add more tests * refactor some code concerning AEVB histogram * fix test for histogram * use mean as deterministic point in Histogram * remove unused import * change names of shortcuts * add names to shared params * add new line at the end of `approximations.py`
1 parent 2cce1c3 commit 0f45720

File tree

6 files changed

+456
-257
lines changed

6 files changed

+456
-257
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from six.moves import cPickle as pickle
1+
import pickle
22
import unittest
33
import numpy as np
44
from theano import theano, tensor as tt
55
import pymc3 as pm
66
from pymc3 import Model, Normal
7-
from pymc3.variational.inference import (
8-
KL, MeanField, ADVI, FullRankADVI,
7+
from pymc3.variational import (
8+
ADVI, FullRankADVI,
9+
Histogram,
910
fit
1011
)
12+
from pymc3.variational.operators import KL
13+
from pymc3.variational.approximations import MeanField
1114

1215
from pymc3.tests import models
1316
from pymc3.tests.helpers import SeededTest
@@ -186,7 +189,7 @@ def cb(*_):
186189
mu_ = Normal('mu', mu=mu0, sd=sd0, testval=0)
187190
Normal('x', mu=mu_, sd=sd, observed=data_t, total_size=n)
188191
inf = self.inference()
189-
approx = inf.fit(self.NITER, callbacks=[cb])
192+
approx = inf.fit(self.NITER, callbacks=[cb], obj_n_mc=10)
190193
trace = approx.sample_vp(10000)
191194
np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.4)
192195
np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.4)
@@ -199,22 +202,34 @@ def test_pickling(self):
199202
inference.fit(20)
200203

201204
def test_aevb(self):
202-
_, model, _ = models.exponential_beta()
205+
_, model, _ = models.exponential_beta(n=2)
203206
x = model.x
204207
y = model.y
205208
mu = theano.shared(x.init_value) * 2
206-
sd = theano.shared(x.init_value) * 3
209+
rho = theano.shared(np.zeros_like(x.init_value))
207210
with model:
208-
inference = self.inference(local_rv={y: (mu, sd)})
209-
inference.fit(3)
211+
inference = self.inference(local_rv={y: (mu, rho)})
212+
approx = inference.fit(3, obj_n_mc=2)
213+
approx.sample_vp(10)
214+
approx.apply_replacements(
215+
y,
216+
more_replacements={x: np.asarray([1, 1], dtype=x.dtype)}
217+
).eval()
218+
219+
def test_profile(self):
220+
with models.multidimensional_model()[1]:
221+
self.inference().run_profiling(10)
210222

211223

212224
class TestMeanField(TestApproximates.Base):
213225
inference = ADVI
214226

215227
def test_approximate(self):
216228
with models.multidimensional_model()[1]:
217-
fit(10, method='advi')
229+
meth = ADVI()
230+
fit(10, method=meth)
231+
self.assertRaises(KeyError, fit, 10, method='undefined')
232+
self.assertRaises(TypeError, fit, 10, method=1)
218233

219234

220235
class TestFullRank(TestApproximates.Base):
@@ -234,11 +249,54 @@ def test_from_advi(self):
234249

235250
def test_combined(self):
236251
with models.multidimensional_model()[1]:
252+
self.assertRaises(ValueError, fit, 10, method='advi->fullrank_advi', frac=1)
237253
fit(10, method='advi->fullrank_advi', frac=.5)
238254

239255
def test_approximate(self):
240256
with models.multidimensional_model()[1]:
241257
fit(10, method='fullrank_advi')
242258

259+
260+
class TestHistogram(SeededTest):
261+
def test_sampling(self):
262+
with models.multidimensional_model()[1]:
263+
full_rank = FullRankADVI()
264+
approx = full_rank.fit(20)
265+
trace0 = approx.sample_vp(10000)
266+
histogram = Histogram(trace0)
267+
trace1 = histogram.sample_vp(100000)
268+
np.testing.assert_allclose(trace0['x'].mean(0), trace1['x'].mean(0), atol=0.01)
269+
np.testing.assert_allclose(trace0['x'].var(0), trace1['x'].var(0), atol=0.01)
270+
271+
def test_aevb_histogram(self):
272+
_, model, _ = models.exponential_beta(n=2)
273+
x = model.x
274+
mu = theano.shared(x.init_value)
275+
rho = theano.shared(np.zeros_like(x.init_value))
276+
with model:
277+
inference = ADVI(local_rv={x: (mu, rho)})
278+
approx = inference.approx
279+
trace0 = approx.sample_vp(1000)
280+
histogram = Histogram(trace0, local_rv={x: (mu, rho)})
281+
trace1 = histogram.sample_vp(10000)
282+
histogram.random(no_rand=True)
283+
histogram.random_fn(no_rand=True)
284+
np.testing.assert_allclose(trace0['y'].mean(0), trace1['y'].mean(0), atol=0.02)
285+
np.testing.assert_allclose(trace0['y'].var(0), trace1['y'].var(0), atol=0.02)
286+
np.testing.assert_allclose(trace0['x'].mean(0), trace1['x'].mean(0), atol=0.02)
287+
np.testing.assert_allclose(trace0['x'].var(0), trace1['x'].var(0), atol=0.02)
288+
289+
def test_random(self):
290+
with models.multidimensional_model()[1]:
291+
full_rank = FullRankADVI()
292+
approx = full_rank.approx
293+
trace0 = approx.sample_vp(10000)
294+
histogram = Histogram(trace0)
295+
histogram.randidx(None).eval()
296+
histogram.randidx(1).eval()
297+
histogram.random_fn(no_rand=True)
298+
histogram.random_fn(no_rand=False)
299+
histogram.histogram_logp.eval()
300+
243301
if __name__ == '__main__':
244302
unittest.main()

pymc3/variational/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,7 @@
2323
FullRankADVI,
2424
fit,
2525
)
26+
from .approximations import Histogram
27+
28+
from . import approximations
29+
from . import operators

0 commit comments

Comments
 (0)