@@ -28,25 +28,23 @@ class MeanField(Approximation):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
     model : PyMC3 model for inference
-
     start : Point
         initial mean
-
     cost_part_grad_scale : float or scalar tensor
         Scaling the score part of the gradient can be useful near the optimum
         for achieving better convergence properties. A common schedule is
         1 at the start and 0 at the end, so a slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
         Yuhuai Wu, David Duvenaud, 2016) for details
-
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use the package global RandomStream or another
         valid value to create an instance-specific one

     References
-    ----------
+    ----------
     Geoffrey Roeder, Yuhuai Wu, David Duvenaud, 2016
     Sticking the Landing: A Simple Reduced-Variance Gradient for ADVI
     approximateinference.org/accepted/RoederEtAl2016.pdf
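
For orientation, a minimal usage sketch of the new flag follows; the toy model is hypothetical, and the example assumes MeanField forwards scale_cost_to_minibatch to the base Approximation, as the docstring above suggests:

    import numpy as np
    import pymc3 as pm
    from pymc3.variational.approximations import MeanField

    # hypothetical toy model; any PyMC3 model would do
    data = np.random.randn(100)
    with pm.Model() as model:
        mu = pm.Normal('mu', mu=0, sd=1)
        pm.Normal('obs', mu=mu, sd=1, observed=data)

    # scale the cost to the minibatch instead of the full dataset
    approx = MeanField(model=model, scale_cost_to_minibatch=True)
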
@@ -109,19 +107,17 @@ class FullRank(Approximation):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
     model : PyMC3 model for inference
-
     start : Point
         initial mean
-
     cost_part_grad_scale : float or scalar tensor
         Scaling the score part of the gradient can be useful near the optimum
         for achieving better convergence properties. A common schedule is
         1 at the start and 0 at the end, so a slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
-        Yuhuai Wu, David Duvenaud, 2016) for details
-
+        Yuhuai Wu, David Duvenaud, 2016) for details
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     seed : None or int
         leave None to use the package global RandomStream or another
         valid value to create an instance-specific one
@@ -133,12 +129,13 @@ class FullRank(Approximation):
     approximateinference.org/accepted/RoederEtAl2016.pdf
     """
     def __init__(self, local_rv=None, model=None, cost_part_grad_scale=1,
+                 scale_cost_to_minibatch=False,
                  gpu_compat=False, seed=None, **kwargs):
         super(FullRank, self).__init__(
             local_rv=local_rv, model=model,
             cost_part_grad_scale=cost_part_grad_scale,
-            seed=seed,
-            **kwargs
+            scale_cost_to_minibatch=scale_cost_to_minibatch,
+            seed=seed, **kwargs
         )
         self.gpu_compat = gpu_compat

141
@@ -213,7 +210,7 @@ def from_mean_field(cls, mean_field, gpu_compat=False):
213
210
"""Construct FullRank from MeanField approximation
214
211
215
212
Parameters
216
- ----------
213
+ ----------
217
214
mean_field : MeanField
218
215
approximation to start with
219
216
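
A short sketch of the conversion path this classmethod provides, again with the hypothetical model from above:

    from pymc3.variational.approximations import MeanField, FullRank

    mean_field = MeanField(model=model)
    # ...fit mean_field (e.g. with ADVI), then widen it to a full-rank Gaussian
    full_rank = FullRank.from_mean_field(mean_field, gpu_compat=False)
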
@@ -256,9 +253,9 @@ class Empirical(Approximation):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
+    scale_cost_to_minibatch : bool, default False
+        Scale cost to minibatch instead of full dataset
     model : PyMC3 model
-
     seed : None or int
         leave None to use the package global RandomStream or another
         valid value to create an instance-specific one
@@ -270,11 +267,12 @@ class Empirical(Approximation):
     ...     trace = sample(1000, step=step)
     ...     histogram = Empirical(trace[100:])
     """
-    def __init__(self, trace, local_rv=None, model=None, seed=None, **kwargs):
+    def __init__(self, trace, local_rv=None,
+                 scale_cost_to_minibatch=False,
+                 model=None, seed=None, **kwargs):
         super(Empirical, self).__init__(
-            local_rv=local_rv, model=model, trace=trace, seed=seed,
-            **kwargs
-        )
+            local_rv=local_rv, scale_cost_to_minibatch=scale_cost_to_minibatch,
+            model=model, trace=trace, seed=seed, **kwargs)

     def check_model(self, model, **kwargs):
         trace = kwargs.get('trace')
@@ -352,7 +350,8 @@ def cov(self):
         return x.T.dot(x) / self.histogram.shape[0]

     @classmethod
-    def from_noise(cls, size, jitter=.01, local_rv=None, start=None, model=None, seed=None):
+    def from_noise(cls, size, jitter=.01, local_rv=None,
+                   start=None, model=None, seed=None, **kwargs):
         """Initialize Histogram with random noise

         Parameters
@@ -366,12 +365,16 @@ def from_noise(cls, size, jitter=.01, local_rv=None, start=None, model=None, seed=None):
         start : initial point
         model : pm.Model
             PyMC3 Model
+        seed : None or int
+            leave None to use the package global RandomStream or another
+            valid value to create an instance-specific one
+        kwargs : other kwargs passed to init

         Returns
-        -------
+        -------
         Empirical
         """
-        hist = cls(None, local_rv=local_rv, model=model, seed=seed)
+        hist = cls(None, local_rv=local_rv, model=model, seed=seed, **kwargs)
         if start is None:
             start = hist.model.test_point
         else:
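
With **kwargs now forwarded, options such as scale_cost_to_minibatch can reach __init__ from from_noise; a sketch with the hypothetical model from above:

    from pymc3.variational.approximations import Empirical

    # 100 particles of jittered noise around the model test point
    histogram = Empirical.from_noise(100, jitter=.01, model=model,
                                     scale_cost_to_minibatch=True)
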
@@ -390,15 +393,15 @@ def sample_approx(approx, draws=100, include_transformed=True):
     """Draw samples from variational posterior.

     Parameters
-    ----------
+    ----------
     approx : Approximation
     draws : int
         Number of random samples.
     include_transformed : bool
         If True, transformed variables are also sampled. Default is True.

     Returns
-    -------
+    -------
     trace : pymc3.backends.base.MultiTrace
         Samples drawn from variational posterior.
     """