Skip to content

Commit e9d7757

Browse files
committed
Add start_sigma to ADVI
1 parent c858f0f commit e9d7757

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

pymc/variational/approximations.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,33 @@ def std(self):
6767
def __init_group__(self, group):
6868
super().__init_group__(group)
6969
if not self._check_user_params():
70-
self.shared_params = self.create_shared_params(self._kwargs.get("start", None))
70+
self.shared_params = self.create_shared_params(
71+
self._kwargs.get("start", None), self._kwargs.get("start_sigma", None)
72+
)
7173
self._finalize_init()
7274

73-
def create_shared_params(self, start=None):
75+
def create_shared_params(self, start=None, start_sigma=None):
76+
# NOTE: `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
77+
# `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes that the free
78+
# variables are given by `self.group` and that the mapping between original variables and flat array is given
79+
# by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
80+
# future code changes that break this assumption.
7481
start = self._prepare_start(start)
75-
rho = np.zeros((self.ddim,))
82+
rho = self._prepare_start_sigma(start_sigma)
7683
return {
7784
"mu": aesara.shared(pm.floatX(start), "mu"),
7885
"rho": aesara.shared(pm.floatX(rho), "rho"),
7986
}
8087

88+
def _prepare_start_sigma(self, start_sigma):
89+
rho = np.zeros((self.ddim,))
90+
if start_sigma is not None:
91+
for name, slice_, *_ in self.ordering.items():
92+
sigma = start_sigma.get(name)
93+
if sigma is not None:
94+
rho[slice_] = np.log(np.exp(np.abs(sigma)) - 1.0)
95+
return rho
96+
8197
@node_property
8298
def symbolic_random(self):
8399
initial = self.symbolic_initial

pymc/variational/inference.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,8 +433,10 @@ class ADVI(KLqp):
433433
random_seed: None or int
434434
leave None to use package global RandomStream or other
435435
valid value to create instance specific one
436-
start: `Point`
436+
start: `dict[str, np.ndarray]` or `StartDict`
437437
starting point for inference
438+
start_sigma: `dict[str, np.ndarray]`
439+
starting standard deviation for inference, only available for method 'advi'
438440
439441
References
440442
----------
@@ -660,6 +662,7 @@ def fit(
660662
model=None,
661663
random_seed=None,
662664
start=None,
665+
start_sigma=None,
663666
inf_kwargs=None,
664667
**kwargs,
665668
):
@@ -684,8 +687,10 @@ def fit(
684687
valid value to create instance specific one
685688
inf_kwargs: dict
686689
additional kwargs passed to :class:`Inference`
687-
start: `Point`
690+
start: `dict[str, np.ndarray]` or `StartDict`
688691
starting point for inference
692+
start_sigma: `dict[str, np.ndarray]`
693+
starting standard deviation for inference, only available for method 'advi'
689694
690695
Other Parameters
691696
----------------
@@ -728,6 +733,10 @@ def fit(
728733
inf_kwargs["random_seed"] = random_seed
729734
if start is not None:
730735
inf_kwargs["start"] = start
736+
if start_sigma is not None:
737+
if method != "advi":
738+
raise NotImplementedError("start_sigma is only available for method advi")
739+
inf_kwargs["start_sigma"] = start_sigma
731740
if model is None:
732741
model = pm.modelcontext(model)
733742
_select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD)

0 commit comments

Comments
 (0)