@@ -306,7 +306,9 @@ class ADVI(Inference):
306
306
Yuhuai Wu, David Duvenaud, 2016) for details
307
307
seed : None or int
308
308
leave None to use package global RandomStream or other
309
- valid value to create instance specific one
309
+ valid value to create instance specific one
310
+ start : Point
311
+ starting point for inference
310
312
311
313
References
312
314
----------
@@ -321,10 +323,12 @@ class ADVI(Inference):
321
323
- Kingma, D. P., & Welling, M. (2014).
322
324
Auto-Encoding Variational Bayes. stat, 1050, 1.
323
325
"""
324
- def __init__ (self , local_rv = None , model = None , cost_part_grad_scale = 1 , seed = None ):
326
+ def __init__ (self , local_rv = None , model = None , cost_part_grad_scale = 1 ,
327
+ seed = None , start = None ):
325
328
super (ADVI , self ).__init__ (
326
329
KL , MeanField , None ,
327
- local_rv = local_rv , model = model , cost_part_grad_scale = cost_part_grad_scale , seed = seed )
330
+ local_rv = local_rv , model = model , cost_part_grad_scale = cost_part_grad_scale ,
331
+ seed = seed , start = start )
328
332
329
333
@classmethod
330
334
def from_mean_field (cls , mean_field ):
@@ -372,6 +376,8 @@ class FullRankADVI(Inference):
372
376
seed : None or int
373
377
leave None to use package global RandomStream or other
374
378
valid value to create instance specific one
379
+ start : Point
380
+ starting point for inference
375
381
376
382
References
377
383
----------
@@ -386,11 +392,12 @@ class FullRankADVI(Inference):
386
392
- Kingma, D. P., & Welling, M. (2014).
387
393
Auto-Encoding Variational Bayes. stat, 1050, 1.
388
394
"""
389
- def __init__ (self , local_rv = None , model = None , cost_part_grad_scale = 1 , gpu_compat = False , seed = None ):
395
+ def __init__ (self , local_rv = None , model = None , cost_part_grad_scale = 1 ,
396
+ gpu_compat = False , seed = None , start = None ):
390
397
super (FullRankADVI , self ).__init__ (
391
398
KL , FullRank , None ,
392
399
local_rv = local_rv , model = model , cost_part_grad_scale = cost_part_grad_scale ,
393
- gpu_compat = gpu_compat , seed = seed )
400
+ gpu_compat = gpu_compat , seed = seed , start = start )
394
401
395
402
@classmethod
396
403
def from_full_rank (cls , full_rank ):
@@ -497,6 +504,8 @@ class SVGD(Inference):
497
504
seed : None or int
498
505
leave None to use package global RandomStream or other
499
506
valid value to create instance specific one
507
+ start : Point
508
+ starting point for inference
500
509
501
510
References
502
511
----------
@@ -515,7 +524,7 @@ def __init__(self, n_particles=100, jitter=.01, model=None, kernel=test_function
515
524
model = model , seed = seed )
516
525
517
526
518
- def fit (n = 10000 , local_rv = None , method = 'advi' , model = None , seed = None , ** kwargs ):
527
+ def fit (n = 10000 , local_rv = None , method = 'advi' , model = None , seed = None , start = None , ** kwargs ):
519
528
"""
520
529
Handy shortcut for using inference methods in functional way
521
530
@@ -536,7 +545,8 @@ def fit(n=10000, local_rv=None, method='advi', model=None, seed=None, **kwargs):
536
545
seed : None or int
537
546
leave None to use package global RandomStream or other
538
547
valid value to create instance specific one
539
-
548
+ start : Point
549
+ starting point for inference
540
550
Returns
541
551
-------
542
552
Approximation
@@ -554,7 +564,7 @@ def fit(n=10000, local_rv=None, method='advi', model=None, seed=None, **kwargs):
554
564
raise ValueError ('frac should be in (0, 1)' )
555
565
n1 = int (n * frac )
556
566
n2 = n - n1
557
- inference = ADVI (local_rv = local_rv , model = model , seed = seed )
567
+ inference = ADVI (local_rv = local_rv , model = model , seed = seed , start = start )
558
568
logger .info ('fitting advi ...' )
559
569
inference .fit (n1 , ** kwargs )
560
570
inference = FullRankADVI .from_advi (inference )
@@ -564,7 +574,8 @@ def fit(n=10000, local_rv=None, method='advi', model=None, seed=None, **kwargs):
564
574
elif isinstance (method , str ):
565
575
try :
566
576
inference = _select [method .lower ()](
567
- local_rv = local_rv , model = model , seed = seed
577
+ local_rv = local_rv , model = model , seed = seed ,
578
+ start = start
568
579
)
569
580
except KeyError :
570
581
raise KeyError ('method should be one of %s '
0 commit comments