Commit ba194cd

ferrine authored and twiecki committed
pretty inference
1 parent 17e3334 commit ba194cd

File tree

1 file changed: +62 −50 lines changed
pymc3/variational/inference.py

Lines changed: 62 additions & 50 deletions
@@ -271,21 +271,28 @@ class ADVI(Inference):
     observed variables with different :code:`total_size` and iterate them independently
     during inference.
 
-    For working with ADVI, we need to give
+    For working with ADVI, we need to give
+
     - The probabilistic model
-        (:code:`model`), the three types of RVs (:code:`observed_RVs`,
+
+        :code:`model` with three types of RVs (:code:`observed_RVs`,
         :code:`global_RVs` and :code:`local_RVs`).
 
     - (optional) Minibatches
+
         The tensors to which mini-batched samples are supplied are
         handled separately by callbacks in the :code:`.fit` method that
         change the storage of a shared theano variable, or by :code:`pm.generator`,
         which automatically iterates over minibatches defined beforehand.
 
     - (optional) Parameters of deterministic mappings
+
         They have to be passed along with other params to the :code:`.fit` method
         as the :code:`more_obj_params` argument.
 
+
+    See Also
+    --------
     For more information concerning the training stage please refer to
     :code:`pymc3.variational.opvi.ObjectiveFunction.step_function`
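The hunk above documents the ADVI workflow: a model whose observed RVs carry :code:`total_size`, fit on minibatched data. A minimal sketch of that usage (not part of this commit; data, sizes, and variable names are illustrative):

    import numpy as np
    import pymc3 as pm

    data = np.random.randn(1000)   # full dataset
    batch = data[:100]             # one minibatch; in practice this storage would be
                                   # swapped per iteration (shared-variable callback
                                   # or pm.generator, as described above)

    with pm.Model() as model:
        mu = pm.Normal('mu', mu=0, sd=10)
        # total_size tells the objective how to rescale the minibatch likelihood
        pm.Normal('obs', mu=mu, sd=1, observed=batch, total_size=len(data))
        approx = pm.ADVI().fit(n=10000)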
@@ -295,35 +302,34 @@ class ADVI(Inference):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
-    model : PyMC3 model for inference
-
-    cost_part_grad_scale : float or scalar tensor
+    model : :class:`Model`
+        PyMC3 model for inference
+    cost_part_grad_scale : `scalar`
         Scaling the score part of the gradient can be useful near the optimum
         for achieving better convergence properties. A common schedule is 1 at
         the start and 0 at the end, so a slow decay will be ok.
         See (Sticking the Landing; Geoffrey Roeder,
         Yuhuai Wu, David Duvenaud, 2016) for details
-    scale_cost_to_minibatch : bool, default False
-        Scale cost to minibatch instead of full dataset
+    scale_cost_to_minibatch : `bool`
+        Scale cost to minibatch instead of full dataset, default False
     seed : None or int
         leave None to use the package-global RandomStream, or pass any other
         valid value to create an instance-specific one
-    start : Point
+    start : `Point`
         starting point for inference
 
     References
     ----------
-    - Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A.,
+    - Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A.,
         and Blei, D. M. (2016). Automatic Differentiation Variational
         Inference. arXiv preprint arXiv:1603.00788.
 
-    - Geoffrey Roeder, Yuhuai Wu, David Duvenaud, 2016
+    - Geoffrey Roeder, Yuhuai Wu, David Duvenaud, 2016
         Sticking the Landing: A Simple Reduced-Variance Gradient for ADVI
         approximateinference.org/accepted/RoederEtAl2016.pdf
 
-    - Kingma, D. P., & Welling, M. (2014).
-        Auto-Encoding Variational Bayes. stat, 1050, 1.
+    - Kingma, D. P., & Welling, M. (2014).
+        Auto-Encoding Variational Bayes. stat, 1050, 1.
     """
     def __init__(self, local_rv=None, model=None,
                  cost_part_grad_scale=1,
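For reference, a hedged sketch of constructing ADVI with the parameters this hunk documents (assumes `pymc3 as pm` is imported and `model` is an existing :class:`Model`; the values are illustrative, not recommendations):

    with model:
        advi = pm.ADVI(cost_part_grad_scale=1,        # anneal toward 0 near the optimum
                       scale_cost_to_minibatch=False, # documented default
                       seed=42)                       # instance-specific RandomStream
        approx = advi.fit(n=10000)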
@@ -343,12 +349,12 @@ def from_mean_field(cls, mean_field):
 
         Parameters
         ----------
-        mean_field : MeanField
+        mean_field : :class:`MeanField`
             approximation to start with
 
         Returns
         -------
-        ADVI
+        :class:`ADVI`
         """
         if not isinstance(mean_field, MeanField):
             raise TypeError('Expected MeanField, got %r' % mean_field)
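A short usage sketch for the classmethod above, assuming `mean_field` is an existing :class:`MeanField` approximation (e.g. from a previous run):

    advi = pm.ADVI.from_mean_field(mean_field)
    advi.fit(n=5000)   # continue optimizing the given approximation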
@@ -369,10 +375,9 @@ class FullRankADVI(Inference):
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
-    model : PyMC3 model for inference
-
-    cost_part_grad_scale : float or scalar tensor
+    model : :class:`Model`
+        PyMC3 model for inference
+    cost_part_grad_scale : `scalar`
         Scaling the score part of the gradient can be useful near the optimum
         for achieving better convergence properties. A common schedule is 1 at
         the start and 0 at the end, so a slow decay will be ok.
@@ -383,21 +388,21 @@ class FullRankADVI(Inference):
     seed : None or int
         leave None to use the package-global RandomStream, or pass any other
         valid value to create an instance-specific one
-    start : Point
+    start : `Point`
         starting point for inference
 
     References
     ----------
-    - Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A.,
+    - Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A.,
         and Blei, D. M. (2016). Automatic Differentiation Variational
         Inference. arXiv preprint arXiv:1603.00788.
 
-    - Geoffrey Roeder, Yuhuai Wu, David Duvenaud, 2016
+    - Geoffrey Roeder, Yuhuai Wu, David Duvenaud, 2016
         Sticking the Landing: A Simple Reduced-Variance Gradient for ADVI
         approximateinference.org/accepted/RoederEtAl2016.pdf
 
-    - Kingma, D. P., & Welling, M. (2014).
-        Auto-Encoding Variational Bayes. stat, 1050, 1.
+    - Kingma, D. P., & Welling, M. (2014).
+        Auto-Encoding Variational Bayes. stat, 1050, 1.
     """
     def __init__(self, local_rv=None, model=None,
                  cost_part_grad_scale=1,
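Usage mirrors ADVI; a minimal sketch with the parameters documented above (again assuming an existing `model`):

    with model:
        inference = pm.FullRankADVI(seed=42)
        approx = inference.fit(n=20000)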
@@ -417,12 +422,12 @@ def from_full_rank(cls, full_rank):
 
         Parameters
         ----------
-        full_rank : FullRank
+        full_rank : :class:`FullRank`
             approximation to start with
 
         Returns
         -------
-        FullRankADVI
+        :class:`FullRankADVI`
         """
         if not isinstance(full_rank, FullRank):
             raise TypeError('Expected FullRank, got %r' % full_rank)
@@ -439,17 +444,17 @@ def from_mean_field(cls, mean_field, gpu_compat=False):
 
         Parameters
         ----------
-        mean_field : MeanField
+        mean_field : :class:`MeanField`
             approximation to start with
 
-        Flags
-        -----
-        gpu_compat : bool
+        Other Parameters
+        ----------------
+        gpu_compat : `bool`
             use GPU compatible version or not
 
         Returns
         -------
-        FullRankADVI
+        :class:`FullRankADVI`
         """
         full_rank = FullRank.from_mean_field(mean_field, gpu_compat)
         inference = object.__new__(cls)
@@ -465,16 +470,16 @@ def from_advi(cls, advi, gpu_compat=False):
 
         Parameters
         ----------
-        advi : ADVI
+        advi : :class:`ADVI`
 
-        Flags
-        -----
+        Other Parameters
+        ----------------
         gpu_compat : bool
             use GPU compatible version or not
 
         Returns
         -------
-        FullRankADVI
+        :class:`FullRankADVI`
         """
         inference = cls.from_mean_field(advi.approx, gpu_compat)
         inference.hist = advi.hist
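Together with :code:`from_mean_field`, this classmethod enables the warm-start pattern behind the 'advi->fullrank_advi' method string: fit a cheap mean-field approximation first, then refine it with a full-rank covariance. A sketch, assuming an existing `model`:

    with model:
        advi = pm.ADVI()
        advi.fit(n=10000)                        # mean-field stage
    inference = pm.FullRankADVI.from_advi(advi)  # note: hist is carried over
    inference.fit(n=10000)                       # full-rank refinement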
@@ -494,35 +499,37 @@ class SVGD(Inference):
     Input: A target distribution with density function :math:`p(x)`
     and a set of initial particles :math:`\{x^0_i\}^n_{i=1}`
     Output: A set of particles :math:`\{x_i\}^n_{i=1}` that approximates the target distribution.
+
     .. math::
 
         x_i^{l+1} \leftarrow x_i^l + \epsilon_l \hat{\phi}^{*}(x_i^l)
         \hat{\phi}^{*}(x) = \frac{1}{n}\sum^{n}_{j=1}[k(x^l_j,x) \nabla_{x^l_j} \log p(x^l_j) + \nabla_{x^l_j} k(x^l_j,x)]
 
     Parameters
     ----------
-    n_particles : int
+    n_particles : `int`
         number of particles to use for approximation
-    jitter :
+    jitter : `float`
         noise sd for initial point
-    model : pm.Model
-    kernel : callable
+    model : :class:`Model`
+        PyMC3 model for inference
+    kernel : `callable`
         kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
     scale_cost_to_minibatch : bool, default False
         Scale cost to minibatch instead of full dataset
-    start : dict
+    start : `dict`
         initial point for inference
-    histogram : Empirical
+    histogram : :class:`Empirical`
         initialize SVGD with given Empirical approximation instead of default initial particles
     seed : None or int
         leave None to use the package-global RandomStream, or pass any other
         valid value to create an instance-specific one
-    start : Point
+    start : `Point`
         starting point for inference
 
     References
     ----------
-    - Qiang Liu, Dilin Wang (2016)
+    - Qiang Liu, Dilin Wang (2016)
         Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm
         arXiv:1608.04471
     """
@@ -546,26 +553,31 @@ def fit(n=10000, local_rv=None, method='advi', model=None, seed=None, start=None
 
     Parameters
     ----------
-    n : int
+    n : `int`
         number of iterations
     local_rv : dict[var->tuple]
         mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-    method : str or Inference
+    method : str or :class:`Inference`
         string name is case insensitive in {'advi', 'fullrank_advi', 'advi->fullrank_advi'}
-    model : Model
-    kwargs : kwargs for Inference.fit
-    frac : float
+    model : :class:`Model`
+        PyMC3 model for inference
+
+    Other Parameters
+    ----------------
+    frac : `float`
         if method is 'advi->fullrank_advi', represents the ADVI fraction of training
     seed : None or int
         leave None to use the package-global RandomStream, or pass any other
         valid value to create an instance-specific one
-    start : Point
+    start : `Point`
         starting point for inference
+    kwargs : kwargs for :meth:`Inference.fit`
+
     Returns
     -------
-    Approximation
+    :class:`Approximation`
     """
     if model is None:
         model = pm.modelcontext(model)
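A hedged sketch of this convenience wrapper; whether extra keywords such as `frac` are forwarded exactly this way, and the sampling call at the end, are assumptions based on the docstring above, with `model` assumed to exist:

    approx = pm.fit(n=30000, method='advi->fullrank_advi',
                    frac=0.5,        # fraction of training spent in the ADVI stage
                    model=model)
    trace = approx.sample(1000)      # draw from the fitted approximation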
