Commit 17e3334

ferrine authored and twiecki committed
pretty opvi
1 parent 7c033e5 commit 17e3334

File tree

1 file changed: +86 −64 lines

pymc3/variational/opvi.py

Lines changed: 86 additions & 64 deletions
@@ -5,7 +5,7 @@
 reveal the true nature of underlying problem. In some applications it can
 yield unreliable decisions.
 
-Recently on NIPS 2017 [OPVI](https://arxiv.org/abs/1610.09033) framework
+Recently on NIPS 2017 `OPVI <https://arxiv.org/abs/1610.09033/>`_ framework
 was presented. It generalizes variational inverence so that the problem is
 build with blocks. The first and essential block is Model itself. Second is
 Approximation, in some cases :math:`log Q(D)` is not really needed. Necessity
@@ -68,8 +68,10 @@ class ObjectiveFunction(object):
 
     Parameters
     ----------
-    op : Operator
-    tf : TestFunction
+    op : :class:`Operator`
+        OPVI Functional operator
+    tf : :class:`TestFunction`
+        OPVI TestFunction
     """
     def __init__(self, op, tf):
         self.op = op
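
A minimal wiring sketch for these two parameters (the hypothetical `my_op` and `my_tf` stand in for concrete Operator and TestFunction instances):

    from pymc3.variational.opvi import ObjectiveFunction

    # the constructor simply stores the operator/test-function pair
    objective = ObjectiveFunction(op=my_op, tf=my_tf)
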
@@ -85,7 +87,7 @@ def random(self, size=None):
 
         Parameters
         ----------
-        size : int
+        size : `int`
             number of samples from distribution
 
         Returns
@@ -101,26 +103,26 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adam, test_optimize
 
         Parameters
         ----------
-        obj_n_mc : int
+        obj_n_mc : `int`
             Number of monte carlo samples used for approximation of objective gradients
-        tf_n_mc : int
+        tf_n_mc : `int`
             Number of monte carlo samples used for approximation of test function gradients
         obj_optimizer : function (loss, params) -> updates
             Optimizer that is used for objective params
         test_optimizer : function (loss, params) -> updates
             Optimizer that is used for test function params
-        more_obj_params : list
+        more_obj_params : `list`
             Add custom params for objective optimizer
-        more_tf_params : list
+        more_tf_params : `list`
             Add custom params for test function optimizer
-        more_updates : dict
+        more_updates : `dict`
             Add custom updates to resulting updates
-        more_replacements : dict
+        more_replacements : `dict`
             Apply custom replacements before calculating gradients
 
         Returns
         -------
-        ObjectiveUpdates
+        :class:`ObjectiveUpdates`
         """
         if more_obj_params is None:
             more_obj_params = []
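
As a hedged sketch of how the returned ObjectiveUpdates are typically consumed (assuming, as the return type suggests, they can be passed as theano updates; `objective` is the hypothetical instance from above):

    import theano

    updates = objective.updates(obj_n_mc=10)          # MC estimate of gradients
    step = theano.function([], [], updates=updates)   # compile one update step
    for _ in range(1000):
        step()                                        # run the optimization loop
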
@@ -182,36 +184,37 @@ def step_function(self, obj_n_mc=None, tf_n_mc=None,
         """Step function that should be called on each optimization step.
 
         Generally it solves the following problem:
+
         .. math::
 
             \textbf{\lambda^{*}} = \inf_{\lambda} \sup_{\theta} t(\mathbb{E}_{\lambda}[(O^{p,q}f_{\theta})(z)])
 
         Parameters
         ----------
-        obj_n_mc : int
+        obj_n_mc : `int`
             Number of monte carlo samples used for approximation of objective gradients
-        tf_n_mc : int
+        tf_n_mc : `int`
             Number of monte carlo samples used for approximation of test function gradients
         obj_optimizer : function (loss, params) -> updates
             Optimizer that is used for objective params
         test_optimizer : function (loss, params) -> updates
             Optimizer that is used for test function params
-        more_obj_params : list
+        more_obj_params : `list`
             Add custom params for objective optimizer
-        more_tf_params : list
+        more_tf_params : `list`
             Add custom params for test function optimizer
-        more_updates : dict
+        more_updates : `dict`
             Add custom updates to resulting updates
-        score : bool
+        score : `bool`
             calculate loss on each step? Defaults to False for speed
-        fn_kwargs : dict
+        fn_kwargs : `dict`
             Add kwargs to theano.function (e.g. `{'profile': True}`)
-        more_replacements : dict
+        more_replacements : `dict`
             Apply custom replacements before calculating gradients
 
         Returns
         -------
-        theano.function
+        `theano.function`
         """
         if fn_kwargs is None:
             fn_kwargs = {}
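
For comparison, the same loop via the compiled step function itself; with score=True each call also returns the current loss (a sketch, reusing the hypothetical `objective`):

    step = objective.step_function(obj_n_mc=10, score=True)
    losses = [step() for _ in range(1000)]  # one update + one score per call
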
@@ -237,11 +240,11 @@ def score_function(self, sc_n_mc=None, more_replacements=None, fn_kwargs=None):
 
         Parameters
         ----------
-        sc_n_mc : int
+        sc_n_mc : `int`
             number of scoring MC samples
         more_replacements:
             Apply custom replacements before compiling a function
-        fn_kwargs:
+        fn_kwargs: `dict`
             arbitrary kwargs passed to theano.function
 
         Returns
@@ -278,10 +281,11 @@ class Operator(object):
 
     Parameters
     ----------
-    approx : Approximation
+    approx : :class:`Approximation`
+        an approximation instance
 
-    Subclassing
-    -----------
+    Notes
+    -----
     For implementing Custom operator it is needed to define :code:`.apply(f)` method
     """
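
To illustrate the subclassing note, a hypothetical KL-style operator; it uses `logq_norm` (visible in the next hunk) and assumes a matching `logp_norm` helper on the base class:

    from pymc3.variational.opvi import Operator

    class MyKL(Operator):
        """Hypothetical KL-divergence-style operator sketch."""
        def apply(self, f):
            z = self.input              # KL-type operators ignore `f`
            return self.logq_norm(z) - self.logp_norm(z)  # logp_norm assumed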

@@ -326,19 +330,21 @@ def logq_norm(self, z):
 
     def apply(self, f): # pragma: no cover
         """Operator itself
+
         .. math::
 
             (O^{p,q}f_{\theta})(z)
 
         Parameters
         ----------
-        f : TestFunction or None if not required
+        f : :class:`TestFunction` or None if not required
             function that takes `z = self.input` and returns
             same dimensional output
 
         Returns
         -------
-        symbolically applied operator
+        tt.TensorVariable
+            symbolically applied operator
         """
         raise NotImplementedError
@@ -426,7 +432,8 @@ def _setup(self, dim):
 
         Parameters
         ----------
-        dim : int dimension of posterior distribution
+        dim : int
+            dimension of posterior distribution
         """
         pass

@@ -445,12 +452,11 @@ class Approximation(object):
     Parameters
     ----------
     local_rv : dict[var->tuple]
-        mapping {model_variable -> local_variable (:math:`\\mu`, math:`\\rho`)}
+        mapping {model_variable -> local_variable (:math:`\\mu`, :math:`\\rho`)}
         Local Vars are used for Autoencoding Variational Bayes
         See (AEVB; Kingma and Welling, 2014) for details
-
-    model : PyMC3 model for inference
-
+    model : :class:`Model`
+        PyMC3 model for inference
     cost_part_grad_scale : float or scalar tensor
         Scaling score part of gradient can be useful near optimum for
         archiving better convergence properties. Common schedule is
@@ -463,10 +469,11 @@ class Approximation(object):
         leave None to use package global RandomStream or other
         valid value to create instance specific one
 
-    Subclassing
-    -----------
+    Notes
+    -----
     Defining an approximation needs
     custom implementation of the following methods:
+
     - :code:`.create_shared_params(**kwargs)`
         Returns {dict|list|theano.shared}
 
@@ -481,19 +488,21 @@ class Approximation(object):
         Returns Scalar
 
     You can also override the following methods:
+
     - :code:`._setup(**kwargs)`
         Do some specific stuff having :code:`kwargs` before calling :code:`.create_shared_params`
 
     - :code:`.check_model(model, **kwargs)`
         Do some specific check for model having :code:`kwargs`
 
-    Notes
-    -----
+    See Also
+    --------
     :code:`kwargs` mentioned above are supplied as additional arguments
     for :code:`Approximation.__init__`
 
     There are some defaults class attributes for approximation classes that can be
     optionally overriden.
+
    - :code:`initial_dist_name`
        string that represents name of the initial distribution.
        In most cases if will be `uniform` or `normal`
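
A skeletal sketch of this contract; the method and attribute names come from the docstring above, the body is a placeholder and the remaining required methods are elided:

    import numpy as np
    import theano
    from pymc3.variational.opvi import Approximation

    class MyApprox(Approximation):
        """Hypothetical subclass following the listed contract."""
        initial_dist_name = 'normal'   # optional class-attribute override

        def create_shared_params(self):
            # must return {dict|list|theano.shared}, per the docstring
            return {'mu': theano.shared(np.zeros(1), 'mu')}
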
@@ -553,7 +562,7 @@ def seed(self, seed=None):
 
         Parameters
         ----------
-        seed : int
+        seed : `int`
         """
         self._seed = seed
         self._rng.seed(seed)
@@ -609,16 +618,16 @@ def construct_replacements(self, include=None, exclude=None,
 
         Parameters
         ----------
-        include : list
+        include : `list`
             latent variables to be replaced
-        exclude : list
+        exclude : `list`
             latent variables to be excluded for replacements
-        more_replacements : dict
+        more_replacements : `dict`
             add custom replacements to graph, e.g. change input source
 
         Returns
         -------
-        dict
+        `dict`
             Replacements
         """
         if include is not None and exclude is not None:
@@ -647,11 +656,11 @@ def apply_replacements(self, node, deterministic=False,
         deterministic : bool
             whether to use zeros as initial distribution
             if True - zero initial point will produce constant latent variables
-        include : list
+        include : `list`
             latent variables to be replaced
-        exclude : list
+        exclude : `list`
             latent variables to be excluded for replacements
-        more_replacements : dict
+        more_replacements : `dict`
             add custom replacements to graph, e.g. change input source
 
         Returns
@@ -674,7 +683,7 @@ def sample_node(self, node, size=100,
         node : Theano Variables (or Theano expressions)
         size : scalar
             number of samples
-        more_replacements : dict
+        more_replacements : `dict`
             add custom replacements to graph, e.g. change input source
 
         Returns
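
A hedged usage sketch for these two graph utilities, assuming `approx` is a fitted Approximation and `mu` a latent model variable:

    # draw 100 posterior samples of an arbitrary expression
    sampled = approx.sample_node(mu ** 2, size=100)
    # or substitute a single deterministic (zero-initial-point) replacement
    fixed = approx.apply_replacements(mu ** 2, deterministic=True)
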
@@ -716,13 +725,16 @@ def initial(self, size, no_rand=False, l=None):
 
         Parameters
         ----------
-        size : int - number of samples
-        no_rand : bool - return zeros if True
-        l : length of sample, defaults to latent space dim
+        size : `int`
+            number of samples
+        no_rand : `bool`
+            return zeros if True
+        l : `int`
+            length of sample, defaults to latent space dim
 
         Returns
         -------
-        Tensor
+        `tt.TensorVariable`
             sampled latent space shape == size + latent_dim
         """

@@ -754,8 +766,10 @@ def random_local(self, size=None, no_rand=False):
 
         Parameters
         ----------
-        size : number of samples from distribution
-        no_rand : whether use deterministic distribution
+        size : `scalar`
+            number of samples from distribution
+        no_rand : `bool`
+            whether use deterministic distribution
 
         Returns
         -------
@@ -771,8 +785,10 @@ def random_global(self, size=None, no_rand=False): # pragma: no cover
 
         Parameters
         ----------
-        size : number of samples from distribution
-        no_rand : whether use deterministic distribution
+        size : `scalar`
+            number of samples from distribution
+        no_rand : `bool`
+            whether use deterministic distribution
 
         Returns
         -------
@@ -785,8 +801,10 @@ def random(self, size=None, no_rand=False):
 
         Parameters
         ----------
-        size : number of samples from distribution
-        no_rand : whether use deterministic distribution
+        size : `scalar`
+            number of samples from distribution
+        no_rand : `bool`
+            whether use deterministic distribution
 
         Returns
         -------
@@ -816,8 +834,10 @@ def random_fn(self):
 
         Parameters
         ----------
-        size : number of samples from distribution
-        no_rand : whether use deterministic distribution
+        size : `int`
+            number of samples from distribution
+        no_rand : `bool`
+            whether use deterministic distribution
 
         Returns
         -------
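
For contrast, a sketch of the symbolic and compiled sampling paths (assuming `random_fn` is exposed as a callable, as its docstring here implies):

    z_sym = approx.random(10)     # symbolic tensor of 10 posterior draws
    z_num = approx.random_fn(10)  # numeric draws from the compiled sampler
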
@@ -844,14 +864,14 @@ def sample(self, draws=1, include_transformed=False):
 
         Parameters
         ----------
-        draws : int
+        draws : `int`
             Number of random samples.
-        include_transformed : bool
+        include_transformed : `bool`
             If True, transformed variables are also sampled. Default is False.
 
         Returns
         -------
-        trace : pymc3.backends.base.MultiTrace
+        trace : :class:`pymc3.backends.base.MultiTrace`
             Samples drawn from variational posterior.
         """
         vars_sampled = get_default_varnames(self.model.unobserved_RVs,
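
End to end, a brief hedged example of drawing from a fitted approximation (assuming the `pm.fit` helper from this branch returns an Approximation):

    import pymc3 as pm

    with pm.Model():
        mu = pm.Normal('mu', mu=0, sd=1)
        approx = pm.fit(n=10000)   # ADVI-style fit

    trace = approx.sample(draws=500, include_transformed=False)
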
@@ -910,15 +930,17 @@ def view(self, space, name, reshape=True):
 
         Parameters
         ----------
-        space : space to take view of variable from
-        name : str
+        space : matrix or vector
+            space to take view of variable from
+        name : `str`
             name of variable
-        reshape : bool
+        reshape : `bool`
             whether to reshape variable from vectorized view
 
         Returns
         -------
-        variable view
+        (reshaped) slice of matrix
+            variable view
         """
         theano_is_here = isinstance(space, tt.TensorVariable)
         slc = self._view[name].slc
