23
23
24
24
@Group .register
25
25
class MeanFieldGroup (Group ):
26
- R"""Mean Field Group
27
-
28
- Mean Field approximation to the posterior where spherical Gaussian family
26
+ R"""Mean Field approximation to the posterior where spherical Gaussian family
29
27
is fitted to minimize KL divergence from True posterior. It is assumed
30
28
that latent space variables are uncorrelated that is the main drawback
31
29
of the method
@@ -101,9 +99,7 @@ def symbolic_logq_not_scaled(self):
101
99
102
100
@Group .register
103
101
class FullRankGroup (Group ):
104
- """Full Rank Group
105
-
106
- Full Rank approximation to the posterior where Multivariate Gaussian family
102
+ """Full Rank approximation to the posterior where Multivariate Gaussian family
107
103
is fitted to minimize KL divergence from True posterior. In contrast to
108
104
MeanField approach correlations between variables are taken in account. The
109
105
main drawback of the method is computational cost.
@@ -216,9 +212,7 @@ def symbolic_random(self):
216
212
217
213
@Group .register
218
214
class EmpiricalGroup (Group ):
219
- """Empirical Group
220
-
221
- Builds Approximation instance from a given trace,
215
+ """Builds Approximation instance from a given trace,
222
216
it has the same interface as variational approximation
223
217
"""
224
218
supports_batched = False
@@ -336,13 +330,12 @@ def __str__(self):
336
330
337
331
338
332
class NormalizingFlowGroup (Group ):
339
- R"""Normalizing Flow Group
340
-
341
- Normalizing flow is a series of invertible transformations on initial distribution.
333
+ R"""Normalizing flow is a series of invertible transformations on initial distribution.
342
334
343
335
.. math::
344
336
345
- z_K = f_K \circ \dots \circ f_2 \circ f_1(z_0)
337
+ z_K &= f_K \circ \dots \circ f_2 \circ f_1(z_0) \\
338
+ & z_0 \sim \mathcal{N}(0, 1)
346
339
347
340
In that case we can compute tractable density for the flow.
348
341
@@ -548,17 +541,23 @@ def __getattr__(self, item):
548
541
549
542
550
543
class MeanField (SingleGroupApproximation ):
551
- """Single Group Mean Field Approximation"""
544
+ __doc__ = """**Single Group Mean Field Approximation**
545
+
546
+ """ + str (MeanFieldGroup .__doc__ )
552
547
_group_class = MeanFieldGroup
553
548
554
549
555
550
class FullRank (SingleGroupApproximation ):
556
- """Single Group Full Rank Approximation"""
551
+ __doc__ = """**Single Group Full Rank Approximation**
552
+
553
+ """ + str (FullRankGroup .__doc__ )
557
554
_group_class = FullRankGroup
558
555
559
556
560
557
class Empirical (SingleGroupApproximation ):
561
- """Single Group Full Rank Approximation"""
558
+ __doc__ = """**Single Group Full Rank Approximation**
559
+
560
+ """ + str (EmpiricalGroup .__doc__ )
562
561
_group_class = EmpiricalGroup
563
562
564
563
def __init__ (self , trace = None , size = None , ** kwargs ):
@@ -568,7 +567,9 @@ def __init__(self, trace=None, size=None, **kwargs):
568
567
569
568
570
569
class NormalizingFlow (SingleGroupApproximation ):
571
- """Single Group Normalizing Flow Approximation"""
570
+ __doc__ = """**Single Group Normalizing Flow Approximation**
571
+
572
+ """ + str (NormalizingFlowGroup .__doc__ )
572
573
_group_class = NormalizingFlowGroup
573
574
574
575
def __init__ (self , flow = NormalizingFlowGroup .default_flow , * args , ** kwargs ):
0 commit comments