
Commit 72c50f4

refactor scalings (#2107)
* reset commits, fix the problem
* pretty docs
* add test for scaling argument
1 parent 56a5f1a commit 72c50f4

6 files changed: +86 −63 lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ def cb(*_):
                 data_t.set_value(next(minibatches))
             mu_ = Normal('mu', mu=mu0, sd=sd0, testval=0)
             Normal('x', mu=mu_, sd=sd, observed=data_t, total_size=n)
-            inf = self.inference()
+            inf = self.inference(scale_cost_to_minibatch=True)
             approx = inf.fit(self.NITER * 3, callbacks=
                              [cb, pm.callbacks.CheckParametersConvergence()],
                              obj_n_mc=10, obj_optimizer=self.optimizer)
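
Note: the updated test exercises the new keyword on a minibatch model. Roughly, `total_size=n` tells PyMC3 to up-weight a minibatch's log-likelihood so the cost estimates the full-dataset objective, and `scale_cost_to_minibatch=True` rescales that cost back down to the minibatch level. A sketch of the arithmetic with made-up numbers (intuition only, not PyMC3 internals):

    # dataset of n points seen through minibatches of m points
    n, m = 1000, 100
    batch_loglik = -140.0                    # hypothetical minibatch log-likelihood
    full_scale = (n / m) * batch_loglik      # default: cost on the full-data scale
    minibatch_scale = full_scale / (n / m)   # scale_cost_to_minibatch=True
    assert minibatch_scale == batch_loglik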

pymc3/theanof.py

Lines changed: 0 additions & 9 deletions
@@ -23,7 +23,6 @@
     'join_nonshared_inputs',
     'make_shared_replacements',
     'generator',
-    'GradScale',
     'set_tt_rng',
     'tt_rng']

@@ -417,13 +416,5 @@ def set_tt_rng(new_rng):
     launch_rng(_tt_rng)


-class GradScale(theano.compile.ViewOp):
-    def __init__(self, multiplier):
-        self.multiplier = multiplier
-
-    def grad(self, args, g_outs):
-        return [self.multiplier * g_out for g_out in g_outs]
-
-
 def floatX_array(x):
     return floatX(np.array(x))
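
For context, the removed `GradScale` op is an identity on the forward pass that multiplies gradients on the backward pass — the kind of op used for `cost_part_grad_scale`-style gradient scaling. A minimal sketch of how it behaves; the class body is copied from the deleted code, while the surrounding usage is illustrative:

    import numpy as np
    import theano
    import theano.tensor as tt

    class GradScale(theano.compile.ViewOp):
        def __init__(self, multiplier):
            self.multiplier = multiplier

        def grad(self, args, g_outs):
            return [self.multiplier * g_out for g_out in g_outs]

    x = tt.dvector('x')
    scale = theano.shared(0.5)             # e.g. annealed from 1 toward 0 during training
    y = (GradScale(scale)(x) ** 2).sum()   # forward value identical to (x ** 2).sum()
    f = theano.function([x], tt.grad(y, x))
    print(f(np.ones(3)))                   # [ 1.  1.  1.] rather than [ 2.  2.  2.]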

pymc3/util.py

Lines changed: 17 additions & 9 deletions
@@ -6,11 +6,13 @@ def get_transformed_name(name, transform):
     ----------
     name : str
         Name to transform
-    transform : object
+    transform : transforms.Transform
         Should be a subclass of `transforms.Transform`

-    Returns:
-        A string to use for the transformed variable
+    Returns
+    -------
+    str
+        A string to use for the transformed variable
     """
     return "{}_{}__".format(name, transform.name)

@@ -24,8 +26,10 @@ def is_transformed_name(name):
     name : str
         Name to check

-    Returns:
-        Boolean, whether the string could have been produced by `get_transormed_name`
+    Returns
+    -------
+    bool
+        Boolean, whether the string could have been produced by `get_transormed_name`
     """
     return name.endswith('__') and name.count('_') >= 3

@@ -39,8 +43,10 @@ def get_untransformed_name(name):
     name : str
         Name to untransform

-    Returns:
-        String with untransformed version of the name.
+    Returns
+    -------
+    str
+        String with untransformed version of the name.
     """
     if not is_transformed_name(name):
         raise ValueError(u'{} does not appear to be a transformed name'.format(name))
@@ -57,8 +63,10 @@ def get_default_varnames(var_iterator, include_transformed):
     include_transformed : boolean
         Should transformed variable names be included in return value

-    Returns:
-        List of variables, possibly filtered
+    Returns
+    -------
+    list
+        List of variables, possibly filtered
     """
     if include_transformed:
         return list(var_iterator)
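
The helpers touched here encode a simple convention: a transformed variable is named `<name>_<transform.name>__`. A quick illustration, assuming the standard `transforms.log` instance (which is not part of this diff):

    from pymc3.distributions import transforms
    from pymc3.util import (get_transformed_name, is_transformed_name,
                            get_untransformed_name)

    name = get_transformed_name('x', transforms.log)   # -> 'x_log__'
    assert is_transformed_name(name)                   # ends in '__' with >= 3 underscores
    assert get_untransformed_name(name) == 'x'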

pymc3/variational/approximations.py

Lines changed: 27 additions & 24 deletions
@@ -28,25 +28,23 @@ class MeanField(Approximation):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
     model : PyMC3 model for inference
-
     start : Point
         initial mean
-
     cost_part_grad_scale : float or scalar tensor
         Scaling score part of gradient can be useful near optimum for
         archiving better convergence properties. Common schedule is
         1 at the start and 0 in the end. So slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
         Yuhuai Wu, David Duvenaud, 2016) for details
-
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one

     References
-    ----------
+    ----------
     Geoffrey Roeder, Yuhuai Wu, David Duvenaud, 2016
     Sticking the Landing: A Simple Reduced-Variance Gradient for ADVI
     approximateinference.org/accepted/RoederEtAl2016.pdf
@@ -109,19 +107,17 @@ class FullRank(Approximation):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
     model : PyMC3 model for inference
-
     start : Point
         initial mean
-
     cost_part_grad_scale : float or scalar tensor
         Scaling score part of gradient can be useful near optimum for
         archiving better convergence properties. Common schedule is
         1 at the start and 0 in the end. So slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
-        Yuhuai Wu, David Duvenaud, 2016) for details
-
+        Yuhuai Wu, David Duvenaud, 2016) for details
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
@@ -133,12 +129,13 @@ class FullRank(Approximation):
     approximateinference.org/accepted/RoederEtAl2016.pdf
     """
     def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+                 scale_cost_to_minibatch=False,
                  gpu_compat=False, seed=None, **kwargs):
         super(FullRank, self).__init__(
             local_rv=local_rv, model=model,
             cost_part_grad_scale=cost_part_grad_scale,
-            seed=seed,
-            **kwargs
+            scale_cost_to_minibatch=scale_cost_to_minibatch,
+            seed=seed, **kwargs
         )
         self.gpu_compat = gpu_compat

@@ -213,7 +210,7 @@ def from_mean_field(cls, mean_field, gpu_compat=False):
         """Construct FullRank from MeanField approximation

         Parameters
-        ----------
+        ----------
         mean_field : MeanField
             approximation to start with

@@ -256,9 +253,9 @@ class Empirical(Approximation):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     model : PyMC3 model
-
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
@@ -270,11 +267,12 @@ class Empirical(Approximation):
     ...     trace = sample(1000, step=step)
    ...     histogram = Empirical(trace[100:])
     """
-    def __init__(self, trace, local_rv=None, model=None, seed=None, **kwargs):
+    def __init__(self, trace, local_rv=None,
+                 scale_cost_to_minibatch=False,
+                 model=None, seed=None, **kwargs):
         super(Empirical, self).__init__(
-            local_rv=local_rv, model=model, trace=trace, seed=seed,
-            **kwargs
-        )
+            local_rv=local_rv, scale_cost_to_minibatch=scale_cost_to_minibatch,
+            model=model, trace=trace, seed=seed, **kwargs)

     def check_model(self, model, **kwargs):
         trace = kwargs.get('trace')
@@ -352,7 +350,8 @@ def cov(self):
         return x.T.dot(x) / self.histogram.shape[0]

     @classmethod
-    def from_noise(cls, size, jitter=.01, local_rv=None, start=None, model=None, seed=None):
+    def from_noise(cls, size, jitter=.01, local_rv=None,
+                   start=None, model=None, seed=None, **kwargs):
         """Initialize Histogram with random noise

         Parameters
@@ -366,12 +365,16 @@ def from_noise(cls, size, jitter=.01, local_rv=None, start=None, model=None, see
         start : initial point
         model : pm.Model
             PyMC3 Model
+        seed : None or int
+            leave None to use package global RandomStream or other
+            valid value to create instance specific one
+        kwargs : other kwargs passed to init

         Returns
-        -------
+        -------
         Empirical
         """
-        hist = cls(None, local_rv=local_rv, model=model, seed=seed)
+        hist = cls(None, local_rv=local_rv, model=model, seed=seed, **kwargs)
         if start is None:
             start = hist.model.test_point
         else:
@@ -390,15 +393,15 @@ def sample_approx(approx, draws=100, include_transformed=True):
     """Draw samples from variational posterior.

     Parameters
-    ----------
+    ----------
     approx : Approximation
     draws : int
         Number of random samples.
     include_transformed : bool
         If True, transformed variables are also sampled. Default is True.

     Returns
-    -------
+    -------
     trace : pymc3.backends.base.MultiTrace
         Samples drawn from variational posterior.
     """

pymc3/variational/inference.py

Lines changed: 23 additions & 7 deletions
@@ -304,6 +304,8 @@ class ADVI(Inference):
         1 at the start and 0 in the end. So slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
         Yuhuai Wu, David Duvenaud, 2016) for details
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
@@ -323,11 +325,15 @@ class ADVI(Inference):
     - Kingma, D. P., & Welling, M. (2014).
       Auto-Encoding Variational Bayes. stat, 1050, 1.
     """
-    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+    def __init__(self, local_rv=None, model=None,
+                 cost_part_grad_scale=1,
+                 scale_cost_to_minibatch=False,
                  seed=None, start=None):
         super(ADVI, self).__init__(
             KL, MeanField, None,
-            local_rv=local_rv, model=model, cost_part_grad_scale=cost_part_grad_scale,
+            local_rv=local_rv, model=model,
+            cost_part_grad_scale=cost_part_grad_scale,
+            scale_cost_to_minibatch=scale_cost_to_minibatch,
             seed=seed, start=start)

     @classmethod
@@ -372,7 +378,8 @@ class FullRankADVI(Inference):
         1 at the start and 0 in the end. So slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
         Yuhuai Wu, David Duvenaud, 2016) for details
-
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use package global RandomStream or other
         valid value to create instance specific one
@@ -392,11 +399,15 @@ class FullRankADVI(Inference):
     - Kingma, D. P., & Welling, M. (2014).
       Auto-Encoding Variational Bayes. stat, 1050, 1.
     """
-    def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+    def __init__(self, local_rv=None, model=None,
+                 cost_part_grad_scale=1,
+                 scale_cost_to_minibatch=False,
                  gpu_compat=False, seed=None, start=None):
         super(FullRankADVI, self).__init__(
             KL, FullRank, None,
-            local_rv=local_rv, model=model, cost_part_grad_scale=cost_part_grad_scale,
+            local_rv=local_rv, model=model,
+            cost_part_grad_scale=cost_part_grad_scale,
+            scale_cost_to_minibatch=scale_cost_to_minibatch,
             gpu_compat=gpu_compat, seed=seed, start=start)

     @classmethod
@@ -497,6 +508,8 @@ class SVGD(Inference):
     model : pm.Model
     kernel : callable
         kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     start : dict
         initial point for inference
     histogram : Empirical
@@ -514,10 +527,13 @@ class SVGD(Inference):
         arXiv:1608.04471
     """
     def __init__(self, n_particles=100, jitter=.01, model=None, kernel=test_functions.rbf,
-                 start=None, histogram=None, seed=None, local_rv=None):
+                 scale_cost_to_minibatch=False, start=None, histogram=None,
+                 seed=None, local_rv=None):
         if histogram is None:
             histogram = Empirical.from_noise(
-                n_particles, jitter=jitter, start=start, model=model, local_rv=local_rv, seed=seed)
+                n_particles, jitter=jitter,
+                scale_cost_to_minibatch=scale_cost_to_minibatch,
+                start=start, model=model, local_rv=local_rv, seed=seed)
         super(SVGD, self).__init__(
             KSD, histogram,
             kernel,
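
Taken together, ADVI, FullRankADVI, and SVGD now accept the flag uniformly and pass it through to their approximations. A usage sketch mirroring the new test; the synthetic data and iteration count are made up:

    import numpy as np
    import theano
    import pymc3 as pm
    from pymc3.variational.inference import ADVI

    n = 1000
    data = np.random.randn(n)
    batch = theano.shared(data[:100])   # swap in new minibatches via a callback

    with pm.Model():
        mu = pm.Normal('mu', mu=0, sd=1)
        pm.Normal('x', mu=mu, sd=1, observed=batch, total_size=n)
        inference = ADVI(scale_cost_to_minibatch=True)  # keyword added by this commit
        approx = inference.fit(1000)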
