@@ -257,9 +257,7 @@ def _infmean(input_array):
257
257
)
258
258
)
259
259
else :
260
- if n == 0 :
261
- logger .info (f"Initialization only" )
262
- elif n < 10 :
260
+ if n < 10 :
263
261
logger .info (f"Finished [100%]: Loss = { scores [- 1 ]:,.5g} " )
264
262
else :
265
263
avg_loss = _infmean (scores [max (0 , i - 1000 ) : i + 1 ])
@@ -435,10 +433,8 @@ class ADVI(KLqp):
435
433
random_seed: None or int
436
434
leave None to use package global RandomStream or other
437
435
valid value to create instance specific one
438
- start: `dict[str, np.ndarray]` or `StartDict `
436
+ start: `Point `
439
437
starting point for inference
440
- start_sigma: `dict[str, np.ndarray]`
441
- starting standard deviation for inference, only available for method 'advi'
442
438
443
439
References
444
440
----------
@@ -468,7 +464,7 @@ class FullRankADVI(KLqp):
468
464
random_seed: None or int
469
465
leave None to use package global RandomStream or other
470
466
valid value to create instance specific one
471
- start: `dict[str, np.ndarray]` or `StartDict `
467
+ start: `Point `
472
468
starting point for inference
473
469
474
470
References
@@ -536,11 +532,13 @@ class SVGD(ImplicitGradient):
536
532
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
537
533
temperature: float
538
534
parameter responsible for exploration, higher temperature gives more broad posterior estimate
539
- start: `dict[str, np.ndarray]` or `StartDict `
535
+ start: `dict`
540
536
initial point for inference
541
537
random_seed: None or int
542
538
leave None to use package global RandomStream or other
543
539
valid value to create instance specific one
540
+ start: `Point`
541
+ starting point for inference
544
542
kwargs: other keyword arguments passed to estimator
545
543
546
544
References
@@ -631,11 +629,7 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
631
629
"is often **underestimated** when using temperature = 1."
632
630
)
633
631
if approx is None :
634
- approx = FullRank (
635
- model = kwargs .pop ("model" , None ),
636
- random_seed = kwargs .pop ("random_seed" , None ),
637
- start = kwargs .pop ("start" , None ),
638
- )
632
+ approx = FullRank (model = kwargs .pop ("model" , None ))
639
633
super ().__init__ (estimator = estimator , approx = approx , kernel = kernel , ** kwargs )
640
634
641
635
def fit (
@@ -666,7 +660,6 @@ def fit(
666
660
model = None ,
667
661
random_seed = None ,
668
662
start = None ,
669
- start_sigma = None ,
670
663
inf_kwargs = None ,
671
664
** kwargs ,
672
665
):
@@ -691,10 +684,8 @@ def fit(
691
684
valid value to create instance specific one
692
685
inf_kwargs: dict
693
686
additional kwargs passed to :class:`Inference`
694
- start: `dict[str, np.ndarray]` or `StartDict `
687
+ start: `Point `
695
688
starting point for inference
696
- start_sigma: `dict[str, np.ndarray]`
697
- starting standard deviation for inference, only available for method 'advi'
698
689
699
690
Other Parameters
700
691
----------------
@@ -737,10 +728,6 @@ def fit(
737
728
inf_kwargs ["random_seed" ] = random_seed
738
729
if start is not None :
739
730
inf_kwargs ["start" ] = start
740
- if start_sigma is not None :
741
- if method != "advi" :
742
- raise NotImplementedError ("start_sigma is only available for method advi" )
743
- inf_kwargs ["start_sigma" ] = start_sigma
744
731
if model is None :
745
732
model = pm .modelcontext (model )
746
733
_select = dict (advi = ADVI , fullrank_advi = FullRankADVI , svgd = SVGD , asvgd = ASVGD )
0 commit comments