Skip to content

Commit c8ef43a

Browse files
ferrinetaku-y
authored andcommitted
refactored tests for variational inference
1 parent 82a732e commit c8ef43a

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,15 @@ class TestApproximates:
6969
class Base(SeededTest):
7070
inference = None
7171
NITER = 12000
72-
optimizer = functools.partial(pm.adam, learning_rate=.01)
72+
optimizer = pm.adagrad_window(learning_rate=0.01)
73+
conv_cb = property(lambda self: [
74+
pm.callbacks.CheckParametersConvergence(
75+
every=500,
76+
diff='relative', tolerance=0.001),
77+
pm.callbacks.CheckParametersConvergence(
78+
every=500,
79+
diff='absolute', tolerance=0.0001)
80+
])
7381

7482
def test_vars_view(self):
7583
_, model, _ = models.multidimensional_model()
@@ -146,11 +154,10 @@ def test_optimizer_with_full_data(self):
146154
inf.fit(10)
147155
approx = inf.fit(self.NITER,
148156
obj_optimizer=self.optimizer,
149-
callbacks=
150-
[pm.callbacks.CheckParametersConvergence()],)
157+
callbacks=self.conv_cb,)
151158
trace = approx.sample(10000)
152-
np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.1)
153-
np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.4)
159+
np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.05)
160+
np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.1)
154161

155162
def test_optimizer_minibatch_with_generator(self):
156163
n = 1000
@@ -175,11 +182,10 @@ def create_minibatch(data):
175182
Normal('x', mu=mu_, sd=sd, observed=minibatches, total_size=n)
176183
inf = self.inference()
177184
approx = inf.fit(self.NITER * 3, obj_optimizer=self.optimizer,
178-
callbacks=
179-
[pm.callbacks.CheckParametersConvergence()])
185+
callbacks=self.conv_cb)
180186
trace = approx.sample(10000)
181-
np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.1)
182-
np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.4)
187+
np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.05)
188+
np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.1)
183189

184190
def test_optimizer_minibatch_with_callback(self):
185191
n = 1000
@@ -208,12 +214,20 @@ def cb(*_):
208214
Normal('x', mu=mu_, sd=sd, observed=data_t, total_size=n)
209215
inf = self.inference(scale_cost_to_minibatch=True)
210216
approx = inf.fit(
211-
self.NITER * 3, callbacks=[
212-
cb, pm.callbacks.CheckParametersConvergence()
213-
], obj_n_mc=10, obj_optimizer=self.optimizer)
217+
self.NITER * 3, callbacks=[cb] + self.conv_cb, obj_optimizer=self.optimizer)
214218
trace = approx.sample(10000)
215-
np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.4)
216-
np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.4)
219+
np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.05)
220+
np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.1)
221+
222+
def test_n_obj_mc(self):
223+
n_samples = 100
224+
xs = np.random.binomial(n=1, p=0.2, size=n_samples)
225+
with pm.Model():
226+
p = pm.Beta('p', alpha=1, beta=1)
227+
pm.Binomial('xs', n=1, p=p, observed=xs)
228+
inf = self.inference(scale_cost_to_minibatch=True)
229+
# should just work
230+
inf.fit(10, obj_n_mc=10, obj_optimizer=self.optimizer)
217231

218232
def test_pickling(self):
219233
with models.multidimensional_model()[1]:
@@ -277,15 +291,14 @@ def test_from_advi(self):
277291

278292
class TestSVGD(TestApproximates.Base):
279293
inference = functools.partial(SVGD, n_particles=100)
280-
NITER = 2500
281-
optimizer = functools.partial(pm.adam, learning_rate=.1)
282294

283295

284296
class TestASVGD(TestApproximates.Base):
297+
NITER = 15000
285298
inference = ASVGD
286-
NITER = 4000
287-
optimizer = functools.partial(pm.adam, learning_rate=.05)
288299
test_aevb = _test_aevb
300+
optimizer = pm.adagrad_window(learning_rate=0.001)
301+
conv_cb = []
289302

290303

291304
class TestEmpirical(SeededTest):

0 commit comments

Comments
 (0)