@@ -257,7 +257,9 @@ def _infmean(input_array):
257
257
)
258
258
)
259
259
else :
260
- if n < 10 :
260
+ if n == 0 :
261
+ logger .info (f"Initialization only" )
262
+ elif n < 10 :
261
263
logger .info (f"Finished [100%]: Loss = { scores [- 1 ]:,.5g} " )
262
264
else :
263
265
avg_loss = _infmean (scores [max (0 , i - 1000 ) : i + 1 ])
@@ -433,8 +435,10 @@ class ADVI(KLqp):
433
435
random_seed: None or int
434
436
leave None to use package global RandomStream or other
435
437
valid value to create instance specific one
436
- start: `Point `
438
+ start: `dict[str, np.ndarray]` or `StartDict `
437
439
starting point for inference
440
+ start_sigma: `dict[str, np.ndarray]`
441
+ starting standard deviation for inference, only available for method 'advi'
438
442
439
443
References
440
444
----------
@@ -464,7 +468,7 @@ class FullRankADVI(KLqp):
464
468
random_seed: None or int
465
469
leave None to use package global RandomStream or other
466
470
valid value to create instance specific one
467
- start: `Point `
471
+ start: `dict[str, np.ndarray]` or `StartDict `
468
472
starting point for inference
469
473
470
474
References
@@ -532,13 +536,11 @@ class SVGD(ImplicitGradient):
532
536
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
533
537
temperature: float
534
538
parameter responsible for exploration, higher temperature gives more broad posterior estimate
535
- start: `dict`
539
+ start: `dict[str, np.ndarray]` or `StartDict `
536
540
initial point for inference
537
541
random_seed: None or int
538
542
leave None to use package global RandomStream or other
539
543
valid value to create instance specific one
540
- start: `Point`
541
- starting point for inference
542
544
kwargs: other keyword arguments passed to estimator
543
545
544
546
References
@@ -629,7 +631,11 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
629
631
"is often **underestimated** when using temperature = 1."
630
632
)
631
633
if approx is None :
632
- approx = FullRank (model = kwargs .pop ("model" , None ))
634
+ approx = FullRank (
635
+ model = kwargs .pop ("model" , None ),
636
+ random_seed = kwargs .pop ("random_seed" , None ),
637
+ start = kwargs .pop ("start" , None ),
638
+ )
633
639
super ().__init__ (estimator = estimator , approx = approx , kernel = kernel , ** kwargs )
634
640
635
641
def fit (
@@ -660,6 +666,7 @@ def fit(
660
666
model = None ,
661
667
random_seed = None ,
662
668
start = None ,
669
+ start_sigma = None ,
663
670
inf_kwargs = None ,
664
671
** kwargs ,
665
672
):
@@ -684,8 +691,10 @@ def fit(
684
691
valid value to create instance specific one
685
692
inf_kwargs: dict
686
693
additional kwargs passed to :class:`Inference`
687
- start: `Point `
694
+ start: `dict[str, np.ndarray]` or `StartDict `
688
695
starting point for inference
696
+ start_sigma: `dict[str, np.ndarray]`
697
+ starting standard deviation for inference, only available for method 'advi'
689
698
690
699
Other Parameters
691
700
----------------
@@ -728,6 +737,10 @@ def fit(
728
737
inf_kwargs ["random_seed" ] = random_seed
729
738
if start is not None :
730
739
inf_kwargs ["start" ] = start
740
+ if start_sigma is not None :
741
+ if method != "advi" :
742
+ raise NotImplementedError ("start_sigma is only available for method advi" )
743
+ inf_kwargs ["start_sigma" ] = start_sigma
731
744
if model is None :
732
745
model = pm .modelcontext (model )
733
746
_select = dict (advi = ADVI , fullrank_advi = FullRankADVI , svgd = SVGD , asvgd = ASVGD )
0 commit comments