Skip to content

Commit ec27b5c

Browse files
Add start_sigma to ADVI (#6096)
* Add `start_sigma` to ADVI
* Add test for `start` and `start_sigma`, plus minor fixes
* Inline `_prepare_start_sigma` and use `expm1`
1 parent bbb3082 commit ec27b5c

File tree

3 files changed

+61
-11
lines changed

3 files changed

+61
-11
lines changed

pymc/tests/test_variational_inference.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,28 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data):
571571
np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2)
572572

573573

574+
def test_fit_start(inference_spec, simple_model):
    """Check that ``start`` (and ``start_sigma`` where supported) seed the approximation.

    The inference object is built with a custom starting mean for ``mu`` and,
    for ADVI only, a custom starting standard deviation.  Fitting with ``n=0``
    runs no optimisation steps, so samples drawn afterwards should reflect the
    initial values.

    Parameters
    ----------
    inference_spec : callable
        Fixture that builds an inference object from keyword arguments.
    simple_model : model context manager
        Fixture providing the model that contains the variable ``mu``.
    """
    mu_init = 17
    mu_sigma_init = 13

    with simple_model:
        # `start_sigma` is only available for ADVI, so probe the concrete
        # inference class once before assembling the keyword arguments.
        has_start_sigma = isinstance(inference_spec(), ADVI)

        kw = {"start": {"mu": mu_init}}
        if has_start_sigma:
            kw["start_sigma"] = {"mu": mu_sigma_init}

        inference = inference_spec(**kw)
        # n=0: initialisation only, so the draws come from the start values.
        trace = inference.fit(n=0).sample(10000)

    np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
    if has_start_sigma:
        np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
594+
595+
574596
def test_profile(inference):
575597
inference.run_profiling(n=100).summary()
576598

pymc/variational/approximations.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,27 @@ def std(self):
6767
def __init_group__(self, group):
6868
super().__init_group__(group)
6969
if not self._check_user_params():
70-
self.shared_params = self.create_shared_params(self._kwargs.get("start", None))
70+
self.shared_params = self.create_shared_params(
71+
self._kwargs.get("start", None), self._kwargs.get("start_sigma", None)
72+
)
7173
self._finalize_init()
7274

73-
def create_shared_params(self, start=None):
75+
def create_shared_params(self, start=None, start_sigma=None):
76+
# NOTE: `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
77+
# `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes that the free
78+
# variables are given by `self.group` and that the mapping between original variables and flat array is given
79+
# by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
80+
# future code changes that break this assumption.
7481
start = self._prepare_start(start)
75-
rho = np.zeros((self.ddim,))
82+
rho1 = np.zeros((self.ddim,))
83+
84+
if start_sigma is not None:
85+
for name, slice_, *_ in self.ordering.values():
86+
sigma = start_sigma.get(name)
87+
if sigma is not None:
88+
rho1[slice_] = np.log(np.expm1(np.abs(sigma)))
89+
rho = rho1
90+
7691
return {
7792
"mu": aesara.shared(pm.floatX(start), "mu"),
7893
"rho": aesara.shared(pm.floatX(rho), "rho"),

pymc/variational/inference.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,9 @@ def _infmean(input_array):
257257
)
258258
)
259259
else:
260-
if n < 10:
260+
if n == 0:
261+
logger.info(f"Initialization only")
262+
elif n < 10:
261263
logger.info(f"Finished [100%]: Loss = {scores[-1]:,.5g}")
262264
else:
263265
avg_loss = _infmean(scores[max(0, i - 1000) : i + 1])
@@ -433,8 +435,10 @@ class ADVI(KLqp):
433435
random_seed: None or int
434436
leave None to use package global RandomStream or other
435437
valid value to create instance specific one
436-
start: `Point`
438+
start: `dict[str, np.ndarray]` or `StartDict`
437439
starting point for inference
440+
start_sigma: `dict[str, np.ndarray]`
441+
starting standard deviation for inference, only available for method 'advi'
438442
439443
References
440444
----------
@@ -464,7 +468,7 @@ class FullRankADVI(KLqp):
464468
random_seed: None or int
465469
leave None to use package global RandomStream or other
466470
valid value to create instance specific one
467-
start: `Point`
471+
start: `dict[str, np.ndarray]` or `StartDict`
468472
starting point for inference
469473
470474
References
@@ -532,13 +536,11 @@ class SVGD(ImplicitGradient):
532536
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
533537
temperature: float
534538
parameter responsible for exploration, higher temperature gives more broad posterior estimate
535-
start: `dict`
539+
start: `dict[str, np.ndarray]` or `StartDict`
536540
initial point for inference
537541
random_seed: None or int
538542
leave None to use package global RandomStream or other
539543
valid value to create instance specific one
540-
start: `Point`
541-
starting point for inference
542544
kwargs: other keyword arguments passed to estimator
543545
544546
References
@@ -629,7 +631,11 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
629631
"is often **underestimated** when using temperature = 1."
630632
)
631633
if approx is None:
632-
approx = FullRank(model=kwargs.pop("model", None))
634+
approx = FullRank(
635+
model=kwargs.pop("model", None),
636+
random_seed=kwargs.pop("random_seed", None),
637+
start=kwargs.pop("start", None),
638+
)
633639
super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)
634640

635641
def fit(
@@ -660,6 +666,7 @@ def fit(
660666
model=None,
661667
random_seed=None,
662668
start=None,
669+
start_sigma=None,
663670
inf_kwargs=None,
664671
**kwargs,
665672
):
@@ -684,8 +691,10 @@ def fit(
684691
valid value to create instance specific one
685692
inf_kwargs: dict
686693
additional kwargs passed to :class:`Inference`
687-
start: `Point`
694+
start: `dict[str, np.ndarray]` or `StartDict`
688695
starting point for inference
696+
start_sigma: `dict[str, np.ndarray]`
697+
starting standard deviation for inference, only available for method 'advi'
689698
690699
Other Parameters
691700
----------------
@@ -728,6 +737,10 @@ def fit(
728737
inf_kwargs["random_seed"] = random_seed
729738
if start is not None:
730739
inf_kwargs["start"] = start
740+
if start_sigma is not None:
741+
if method != "advi":
742+
raise NotImplementedError("start_sigma is only available for method advi")
743+
inf_kwargs["start_sigma"] = start_sigma
731744
if model is None:
732745
model = pm.modelcontext(model)
733746
_select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD)

0 commit comments

Comments
 (0)