
Commit 2b48b94

ferrine authored and twiecki committed
bug fix devoted to #2109 but for VI
1 parent 27cd067 commit 2b48b94

File tree

3 files changed: +53 lines, -18 lines


pymc3/tests/test_variational_inference.py

Lines changed: 4 additions & 3 deletions
@@ -143,12 +143,12 @@ def test_optimizer_with_full_data(self):
         with Model():
             mu_ = Normal('mu', mu=mu0, sd=sd0, testval=0)
             Normal('x', mu=mu_, sd=sd, observed=data)
-            inf = self.inference()
+            inf = self.inference(start={})
             inf.fit(10)
             approx = inf.fit(self.NITER,
                              obj_optimizer=self.optimizer,
                              callbacks=
-                             [pm.callbacks.CheckParametersConvergence()])
+                             [pm.callbacks.CheckParametersConvergence()],)
             trace = approx.sample(10000)
             np.testing.assert_allclose(np.mean(trace['mu']), mu_post, rtol=0.1)
             np.testing.assert_allclose(np.std(trace['mu']), np.sqrt(1. / d), rtol=0.4)
@@ -342,14 +342,15 @@ def test_init_from_noize(self):
     [
         ('undefined', dict(), KeyError),
         (1, dict(), TypeError),
-        (_advi, dict(), None),
+        (_advi, dict(start={}), None),
         (_fullrank_advi, dict(), None),
         (_svgd, dict(), None),
         ('advi', dict(), None),
         ('advi->fullrank_advi', dict(frac=.1), None),
         ('advi->fullrank_advi', dict(frac=1), ValueError),
         ('fullrank_advi', dict(), None),
         ('svgd', dict(), None),
+        ('svgd', dict(start={}), None),
         ('svgd', dict(local_rv={_model.free_RVs[0]: (0, 1)}), ValueError)
     ]
 )
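The test change above simply passes an empty start dict to the inference constructor, which should behave exactly like the default (the model test point is used unchanged). A minimal sketch of the same pattern outside the test suite, assuming this commit's new `start` keyword; the model and variable names are illustrative, not taken from the repository:

    import numpy as np
    import pymc3 as pm

    data = np.random.randn(100) + 5.

    with pm.Model():
        mu = pm.Normal('mu', mu=0., sd=10., testval=0)
        pm.Normal('x', mu=mu, sd=1., observed=data)
        # start={} adds nothing on top of model.test_point, so this matches the default
        inference = pm.ADVI(start={})
        approx = inference.fit(10000,
                               callbacks=[pm.callbacks.CheckParametersConvergence()])
        trace = approx.sample(1000)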

pymc3/variational/approximations.py

Lines changed: 29 additions & 6 deletions
@@ -64,7 +64,14 @@ def cov(self):
         return tt.diag(rho2sd(self.rho)**2)

     def create_shared_params(self, **kwargs):
-        start = self.gbij.map(kwargs.get('start', self.model.test_point))
+        start = kwargs.get('start')
+        if start is None:
+            start = self.model.test_point
+        else:
+            start_ = self.model.test_point.copy()
+            pm.sampling._update_start_vals(start_, start, self.model)
+            start = start_
+        start = self.gbij.map(start)
         return {'mu': theano.shared(
                     pm.floatX(start),
                     'mu'),
@@ -125,11 +132,13 @@ class FullRank(Approximation):
     Sticking the Landing: A Simple Reduced-Variance Gradient for ADVI
     approximateinference.org/accepted/RoederEtAl2016.pdf
     """
-    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1, gpu_compat=False, seed=None):
+    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+                 gpu_compat=False, seed=None, **kwargs):
         super(FullRank, self).__init__(
             local_rv=local_rv, model=model,
             cost_part_grad_scale=cost_part_grad_scale,
-            seed=seed
+            seed=seed,
+            **kwargs
         )
         self.gpu_compat = gpu_compat

@@ -161,7 +170,14 @@ def tril_index_matrix(self):
         return tril_index_matrix

     def create_shared_params(self, **kwargs):
-        start = self.gbij.map(kwargs.get('start', self.model.test_point))
+        start = kwargs.get('start')
+        if start is None:
+            start = self.model.test_point
+        else:
+            start_ = self.model.test_point.copy()
+            pm.sampling._update_start_vals(start_, start, self.model)
+            start = start_
+        start = self.gbij.map(start)
         n = self.global_size
         L_tril = (
             np.eye(n)
@@ -254,8 +270,11 @@ class Empirical(Approximation):
     ...     trace = sample(1000, step=step)
     ...     histogram = Empirical(trace[100:])
     """
-    def __init__(self, trace, local_rv=None, model=None, seed=None):
-        super(Empirical, self).__init__(local_rv=local_rv, model=model, trace=trace, seed=seed)
+    def __init__(self, trace, local_rv=None, model=None, seed=None, **kwargs):
+        super(Empirical, self).__init__(
+            local_rv=local_rv, model=model, trace=trace, seed=seed,
+            **kwargs
+        )

     def check_model(self, model, **kwargs):
         trace = kwargs.get('trace')
@@ -355,6 +374,10 @@ def from_noise(cls, size, jitter=.01, local_rv=None, start=None, model=None, see
         hist = cls(None, local_rv=local_rv, model=model, seed=seed)
         if start is None:
             start = hist.model.test_point
+        else:
+            start_ = hist.model.test_point.copy()
+            pm.sampling._update_start_vals(start_, start, hist.model)
+            start = start_
         start = hist.gbij.map(start)
         # Initialize particles
         x0 = np.tile(start, (size, 1))
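All three approximations now resolve a user-supplied start point the same way: copy `model.test_point`, let `pm.sampling._update_start_vals` overwrite the entries the user actually provided, and only then map the merged dict through `gbij` into the flat parameter vector. Roughly the merge semantics, sketched with plain dicts; `merge_start` below is an illustrative stand-in and ignores the transformed-variable handling the real helper also performs:

    def merge_start(test_point, start):
        # user-supplied values win; unspecified variables keep their test-point values
        merged = dict(test_point)      # copy, never mutate model.test_point itself
        merged.update(start or {})
        return merged

    # e.g. a model whose test point is {'mu': 0.0, 'sd_log__': 0.0}
    print(merge_start({'mu': 0.0, 'sd_log__': 0.0}, {'mu': 3.0}))
    # -> {'mu': 3.0, 'sd_log__': 0.0}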

pymc3/variational/inference.py

Lines changed: 20 additions & 9 deletions
@@ -306,7 +306,9 @@ class ADVI(Inference):
         Yuhuai Wu, David Duvenaud, 2016) for details
     seed : None or int
         leave None to use package global RandomStream or other
-        valid value to create instance specific one
+        valid value to create instance specific one
+    start : Point
+        starting point for inference

     References
     ----------
@@ -321,10 +323,12 @@ class ADVI(Inference):
     -   Kingma, D. P., & Welling, M. (2014).
         Auto-Encoding Variational Bayes. stat, 1050, 1.
     """
-    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1, seed=None):
+    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+                 seed=None, start=None):
         super(ADVI, self).__init__(
             KL, MeanField, None,
-            local_rv=local_rv, model=model, cost_part_grad_scale=cost_part_grad_scale, seed=seed)
+            local_rv=local_rv, model=model, cost_part_grad_scale=cost_part_grad_scale,
+            seed=seed, start=start)

     @classmethod
     def from_mean_field(cls, mean_field):
@@ -372,6 +376,8 @@ class FullRankADVI(Inference):
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
+    start : Point
+        starting point for inference

     References
     ----------
@@ -386,11 +392,12 @@ class FullRankADVI(Inference):
     -   Kingma, D. P., & Welling, M. (2014).
         Auto-Encoding Variational Bayes. stat, 1050, 1.
     """
-    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1, gpu_compat=False, seed=None):
+    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+                 gpu_compat=False, seed=None, start=None):
         super(FullRankADVI, self).__init__(
             KL, FullRank, None,
             local_rv=local_rv, model=model, cost_part_grad_scale=cost_part_grad_scale,
-            gpu_compat=gpu_compat, seed=seed)
+            gpu_compat=gpu_compat, seed=seed, start=start)

     @classmethod
     def from_full_rank(cls, full_rank):
@@ -497,6 +504,8 @@ class SVGD(Inference):
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
+    start : Point
+        starting point for inference

     References
     ----------
@@ -515,7 +524,7 @@ def __init__(self, n_particles=100, jitter=.01, model=None, kernel=test_function
             model=model, seed=seed)


-def fit(n=10000, local_rv=None, method='advi', model=None, seed=None, **kwargs):
+def fit(n=10000, local_rv=None, method='advi', model=None, seed=None, start=None, **kwargs):
     """
     Handy shortcut for using inference methods in functional way

@@ -536,7 +545,8 @@ def fit(n=10000, local_rv=None, method='advi', model=None, seed=None, **kwargs):
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
-
+    start : Point
+        starting point for inference
     Returns
     -------
     Approximation
@@ -554,7 +564,7 @@ def fit(n=10000, local_rv=None, method='advi', model=None, seed=None, **kwargs):
            raise ValueError('frac should be in (0, 1)')
        n1 = int(n * frac)
        n2 = n-n1
-       inference = ADVI(local_rv=local_rv, model=model, seed=seed)
+       inference = ADVI(local_rv=local_rv, model=model, seed=seed, start=start)
        logger.info('fitting advi ...')
        inference.fit(n1, **kwargs)
        inference = FullRankADVI.from_advi(inference)
@@ -564,7 +574,8 @@ def fit(n=10000, local_rv=None, method='advi', model=None, seed=None, **kwargs):
     elif isinstance(method, str):
         try:
             inference = _select[method.lower()](
-                local_rv=local_rv, model=model, seed=seed
+                local_rv=local_rv, model=model, seed=seed,
+                start=start
             )
         except KeyError:
             raise KeyError('method should be one of %s '
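Because `fit()` now forwards `start` to whichever Inference class is selected, a custom starting point can be supplied through the functional shortcut as well. A sketch, assuming a model with a single free variable named 'mu'; the names and values are illustrative, not from the commit:

    import numpy as np
    import pymc3 as pm

    data = np.random.randn(200) + 2.

    with pm.Model():
        mu = pm.Normal('mu', mu=0., sd=10.)
        pm.Normal('obs', mu=mu, sd=1., observed=data)
        # the start point is merged into model.test_point rather than replacing it
        approx = pm.fit(n=20000, method='advi', start={'mu': np.mean(data)})
        trace = approx.sample(500)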
