Sample vp for Approximations, deprecate old ADVI #2027
Conversation
pymc3/variational/advi.py (outdated)

    @@ -108,6 +108,9 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
         and Blei, D. M. (2016). Automatic Differentiation Variational
         Inference. arXiv preprint arXiv:1603.00788.
         """
    +    import warnings
    +    warnings.warn('Old ADVI interface is deprecated and be removed in future, use pm.ADVI instead',
Typo: "and be removed" should read "and will be removed".
Seems like lots of refactoring will be needed for init_nuts
pymc3/variational/advi.py (outdated)

    @@ -108,6 +108,10 @@ def advi(vars=None, start=None, model=None, n=5000, accurate_elbo=False,
         and Blei, D. M. (2016). Automatic Differentiation Variational
         Inference. arXiv preprint arXiv:1603.00788.
         """
    +    import warnings
    +    warnings.warn('Old ADVI interface is deprecated and will '
    +                  'be removed in future, use pm.ADVI instead',
Shouldn't we tell people to use pm.fit() instead?
Maybe
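(For illustration, a sketch of how the reworded warning could point at the high-level entry point instead; the `DeprecationWarning` category and `stacklevel` argument are suggestions, not code from this PR:)

    import warnings

    # Hypothetical rewording: steer users to pm.fit() rather than pm.ADVI
    warnings.warn('Old ADVI interface is deprecated and will be removed '
                  'in a future release; use pm.fit(method="advi") instead',
                  DeprecationWarning, stacklevel=2)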
pymc3/variational/callbacks.py (outdated)

            raise NotImplementedError


    class CheckLossConvergence(Callback):
Very elegant.
We should also add the classic ADVI convergence criterion: https://github.com/pymc-devs/pymc3/blob/master/pymc3/variational/advi.py#L180
Okay
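(For reference, a minimal sketch of that classic criterion recast in the new callback style: stop when the relative change between two consecutive window averages of the loss is small. The class name, the defaults, the `(approx, loss_hist, i)` call signature, and the assumption that the fit loop stops on `StopIteration` are illustrative, not this PR's final code:)

    import numpy as np

    class CheckAverageLossConvergence(object):
        """Hypothetical callback mirroring the old advi() check: stop when
        the relative change of the window-averaged loss drops below
        `tolerance`."""
        def __init__(self, every=100, window_size=100, tolerance=1e-2):
            self.every = every
            self.window_size = window_size
            self.tolerance = tolerance

        def __call__(self, approx, loss_hist, i):
            if i % self.every or len(loss_hist) < 2 * self.window_size:
                return
            # Mean loss of the last window vs. the window before it
            prev = np.mean(loss_hist[-2 * self.window_size:-self.window_size])
            curr = np.mean(loss_hist[-self.window_size:])
            if np.abs(curr - prev) / np.abs(prev) < self.tolerance:
                # assumed convention: raising StopIteration ends fitting
                raise StopIteration('Average loss converged.')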
@ferrine This starts to look pretty good. Why is init_nuts trickier than what you already have here?
Do we want to do the …
@fonnesbeck The only option is to rename to …
@twiecki why do we need some strange tests for random states?
@twiecki init_nuts became a problem because of the start point for fine-tuning ADVI from MAP. I did not support that case.
@ferrine …
@fonnesbeck good point. So method of …
@ferrine that's a good idea, I agree.
Coming back to the tests: they are failing. What are these tests for? https://github.com/pymc-devs/pymc3/blob/master/pymc3/tests/test_sampling.py#L37
That test was added when we fixed a bug where every parallel sampler generated the same samples because they all used the same seed.
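(For context, a runnable sketch of the failure mode that test guards against; this is not code from the PR, just the seeding pattern behind the bug and its fix:)

    import numpy as np

    seed = 42
    # Buggy pattern: every parallel chain inherits the same seed,
    # so all chains produce identical draws.
    same = [np.random.RandomState(seed).normal(size=3) for _ in range(2)]
    assert np.allclose(same[0], same[1])

    # Fixed pattern: derive a distinct seed per chain.
    diff = [np.random.RandomState(seed + i).normal(size=3) for i in range(2)]
    assert not np.allclose(diff[0], diff[1])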
Hmm, interesting
@twiecki I can't find where I use the global numpy RNG in my code :(
@ferrine Does that test fail locally too?
yes, as a sanity check I used the following code:

    s1 = np.random.get_state()
    approx = pm.fit(
        n=10, method='advi', model=nnet,
        callbacks=[pm.callbacks.CheckLossConvergence(every=1, window_size=2)]
    )
    start = approx.sample(draws=4)
    cov = approx.cov.eval()
    s2 = np.random.get_state()
    # all True -> fitting and sampling left the global numpy RNG state untouched
    (s1[0] == s2[0]), (s1[1] == s2[1]).all(), (s1[2] == s2[2]), (s1[3] == s2[3])
    --------------
    True, True, True, True
Good thing we have that test.
Following the discussion in #1953. CC @twiecki, @fonnesbeck, @aseyboldt, @jsalvatier, @taku-y, @springcoil
I've added the Inference module to the docs; how can I check that it will be displayed?
Will there be a progress bar for …
Do you mean …
What about merging?
pymc3/variational/approximations.py (outdated)

        mapping {model_variable -> local_variable}
    local_rv : dict[var->tuple]
        Experimental for Empirical Distribution
        mapping {model_variable -> local_variable (:math:`\\mu`, math:`\\rho`)}
typo: `math:` should be `:math:`
Looks great, thanks!
        v_params = pm.variational.advi(n=n_init, start=start,
                                       random_seed=random_seed)
        cov = np.power(model.dict_to_array(v_params.stds), 2)
        approx = pm.MeanField(model=model, start=start)
I'm getting meaningful differences with this change on this model: http://twiecki.github.io/blog/2017/03/14/random-walk-deep-net/
It does not abort when converged and sampling from this initialization is very slow.
Do you see any way to improve it?
@twiecki It was suggested that I try running with a minimal stopping criterion. It will be interesting to see how it affects performance.
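(For illustration, one way the convergence callback could be wired into the new init path so fitting aborts once converged — a sketch assembled from pieces seen earlier in this thread (`pm.fit`, `CheckLossConvergence`, `approx.sample`, `approx.cov`), not the PR's final code; `model`, `n_init`, and `chains` are assumed to be in scope:)

    import pymc3 as pm

    # Hypothetical wiring: fit with a convergence callback so the
    # optimization can stop early, then seed NUTS from the approximation.
    approx = pm.fit(
        n=n_init, method='advi', model=model,
        callbacks=[pm.callbacks.CheckLossConvergence(every=100, window_size=100)],
    )
    start = approx.sample(draws=chains)  # one starting point per chain
    cov = approx.cov.eval()              # covariance estimate for the NUTS scale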
Following #2026