Skip to content

Commit bf843c7

Browse files
authored
Revert "Add start_sigma to ADVI (#6096)" (#6130)
This reverts commit ec27b5c.
1 parent 2296350 commit bf843c7

File tree

3 files changed

+11
-61
lines changed

3 files changed

+11
-61
lines changed

pymc/tests/test_variational_inference.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -571,28 +571,6 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data):
571571
np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2)
572572

573573

574-
def test_fit_start(inference_spec, simple_model):
575-
mu_init = 17
576-
mu_sigma_init = 13
577-
578-
with simple_model:
579-
if type(inference_spec()) == ADVI:
580-
has_start_sigma = True
581-
else:
582-
has_start_sigma = False
583-
584-
kw = {"start": {"mu": mu_init}}
585-
if has_start_sigma:
586-
kw.update({"start_sigma": {"mu": mu_sigma_init}})
587-
588-
with simple_model:
589-
inference = inference_spec(**kw)
590-
trace = inference.fit(n=0).sample(10000)
591-
np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
592-
if has_start_sigma:
593-
np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
594-
595-
596574
def test_profile(inference):
597575
inference.run_profiling(n=100).summary()
598576

pymc/variational/approximations.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,27 +67,12 @@ def std(self):
6767
def __init_group__(self, group):
6868
super().__init_group__(group)
6969
if not self._check_user_params():
70-
self.shared_params = self.create_shared_params(
71-
self._kwargs.get("start", None), self._kwargs.get("start_sigma", None)
72-
)
70+
self.shared_params = self.create_shared_params(self._kwargs.get("start", None))
7371
self._finalize_init()
7472

75-
def create_shared_params(self, start=None, start_sigma=None):
76-
# NOTE: `Group._prepare_start` uses `self.model.free_RVs` to identify free variables and
77-
# `DictToArrayBijection` to turn them into a flat array, while `Approximation.rslice` assumes that the free
78-
# variables are given by `self.group` and that the mapping between original variables and flat array is given
79-
# by `self.ordering`. In the cases I looked into these turn out to be the same, but there may be edge cases or
80-
# future code changes that break this assumption.
73+
def create_shared_params(self, start=None):
8174
start = self._prepare_start(start)
82-
rho1 = np.zeros((self.ddim,))
83-
84-
if start_sigma is not None:
85-
for name, slice_, *_ in self.ordering.values():
86-
sigma = start_sigma.get(name)
87-
if sigma is not None:
88-
rho1[slice_] = np.log(np.expm1(np.abs(sigma)))
89-
rho = rho1
90-
75+
rho = np.zeros((self.ddim,))
9176
return {
9277
"mu": aesara.shared(pm.floatX(start), "mu"),
9378
"rho": aesara.shared(pm.floatX(rho), "rho"),

pymc/variational/inference.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,7 @@ def _infmean(input_array):
257257
)
258258
)
259259
else:
260-
if n == 0:
261-
logger.info(f"Initialization only")
262-
elif n < 10:
260+
if n < 10:
263261
logger.info(f"Finished [100%]: Loss = {scores[-1]:,.5g}")
264262
else:
265263
avg_loss = _infmean(scores[max(0, i - 1000) : i + 1])
@@ -435,10 +433,8 @@ class ADVI(KLqp):
435433
random_seed: None or int
436434
leave None to use package global RandomStream or other
437435
valid value to create instance specific one
438-
start: `dict[str, np.ndarray]` or `StartDict`
436+
start: `Point`
439437
starting point for inference
440-
start_sigma: `dict[str, np.ndarray]`
441-
starting standard deviation for inference, only available for method 'advi'
442438
443439
References
444440
----------
@@ -468,7 +464,7 @@ class FullRankADVI(KLqp):
468464
random_seed: None or int
469465
leave None to use package global RandomStream or other
470466
valid value to create instance specific one
471-
start: `dict[str, np.ndarray]` or `StartDict`
467+
start: `Point`
472468
starting point for inference
473469
474470
References
@@ -536,11 +532,13 @@ class SVGD(ImplicitGradient):
536532
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
537533
temperature: float
538534
parameter responsible for exploration, higher temperature gives more broad posterior estimate
539-
start: `dict[str, np.ndarray]` or `StartDict`
535+
start: `dict`
540536
initial point for inference
541537
random_seed: None or int
542538
leave None to use package global RandomStream or other
543539
valid value to create instance specific one
540+
start: `Point`
541+
starting point for inference
544542
kwargs: other keyword arguments passed to estimator
545543
546544
References
@@ -631,11 +629,7 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
631629
"is often **underestimated** when using temperature = 1."
632630
)
633631
if approx is None:
634-
approx = FullRank(
635-
model=kwargs.pop("model", None),
636-
random_seed=kwargs.pop("random_seed", None),
637-
start=kwargs.pop("start", None),
638-
)
632+
approx = FullRank(model=kwargs.pop("model", None))
639633
super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)
640634

641635
def fit(
@@ -666,7 +660,6 @@ def fit(
666660
model=None,
667661
random_seed=None,
668662
start=None,
669-
start_sigma=None,
670663
inf_kwargs=None,
671664
**kwargs,
672665
):
@@ -691,10 +684,8 @@ def fit(
691684
valid value to create instance specific one
692685
inf_kwargs: dict
693686
additional kwargs passed to :class:`Inference`
694-
start: `dict[str, np.ndarray]` or `StartDict`
687+
start: `Point`
695688
starting point for inference
696-
start_sigma: `dict[str, np.ndarray]`
697-
starting standard deviation for inference, only available for method 'advi'
698689
699690
Other Parameters
700691
----------------
@@ -737,10 +728,6 @@ def fit(
737728
inf_kwargs["random_seed"] = random_seed
738729
if start is not None:
739730
inf_kwargs["start"] = start
740-
if start_sigma is not None:
741-
if method != "advi":
742-
raise NotImplementedError("start_sigma is only available for method advi")
743-
inf_kwargs["start_sigma"] = start_sigma
744731
if model is None:
745732
model = pm.modelcontext(model)
746733
_select = dict(advi=ADVI, fullrank_advi=FullRankADVI, svgd=SVGD, asvgd=ASVGD)

0 commit comments

Comments
 (0)