-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Adds quadratic approximation #4847
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
78366d1
3f9a33a
bee10e1
f0e7030
e15c065
c65147c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright 2021 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Functions for Quadratic Approximation.""" | ||
|
||
import arviz as az | ||
import numpy as np | ||
import scipy | ||
|
||
__all__ = ["quadratic_approximation"] | ||
|
||
from pymc3.tuning import find_hessian, find_MAP | ||
|
||
|
||
def quadratic_approximation(vars, n_chains=2, n_samples=10_000): | ||
"""Finds the quadratic approximation to the posterior, also known as the Laplace approximation. | ||
|
||
NOTE: The quadratic approximation only works well for unimodal and roughly symmetrical posteriors of continuous variables. | ||
The usual MCMC convergence and mixing statistics (e.g. R-hat, ESS) will NOT tell you anything about how well this approximation fits your actual (unknown) posterior, indeed they'll always be extremely nice since all "chains" are sampling from exactly the same distribution, the posterior quadratic approximation. | ||
Use at your own risk. | ||
|
||
See Chapter 4 of "Bayesian Data Analysis" 3rd edition for background. | ||
|
||
Returns an arviz.InferenceData object for compatibility by sampling from the approximated quadratic posterior. Note these are NOT MCMC samples, so the notion of chains is meaningless, and is only included for downstream compatibility with Arviz. | ||
|
||
Also returns the exact posterior approximation as a scipy.stats.multivariate_normal distribution. | ||
|
||
Parameters | ||
---------- | ||
vars: list | ||
List of variables to approximate the posterior for. | ||
n_chains: int | ||
How many chains to simulate. | ||
n_samples: int | ||
How many samples to sample from the approximate posterior for each chain. | ||
|
||
Returns | ||
------- | ||
arviz.InferenceData: | ||
InferenceData with samples from the approximate posterior | ||
scipy.stats.multivariate_normal: | ||
Multivariate normal posterior approximation | ||
""" | ||
map = find_MAP(vars=vars) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there any arguments to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm. I think I'd say no. It's a leaky abstraction. If someone raises an issue where they need to pass args to find_MAP we can think about how to best do it at that time. That's my 2 cents. |
||
H = find_hessian(map, vars=vars) | ||
cov = np.linalg.inv(H) | ||
mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars]) | ||
posterior = scipy.stats.multivariate_normal(mean=mean, cov=cov) | ||
ColCarroll marked this conversation as resolved.
Show resolved
Hide resolved
|
||
draws = posterior.rvs((n_chains, n_samples)) | ||
samples = {} | ||
i = 0 | ||
for v in vars: | ||
var_size = map[v.name].size | ||
samples[v.name] = draws[:, :, i : i + var_size].squeeze() | ||
i += var_size | ||
return az.convert_to_inference_data(samples), posterior |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import arviz as az | ||
import numpy as np | ||
|
||
import pymc3 as pm | ||
|
||
from pymc3.tests.helpers import SeededTest | ||
|
||
|
||
class TestQuadraticApproximation(SeededTest): | ||
def setup_method(self): | ||
super().setup_method() | ||
|
||
def test_recovers_analytical_quadratic_approximation_in_normal_with_unknown_mean_and_variance(): | ||
y = np.array([2642, 3503, 4358]) | ||
n = y.size | ||
|
||
with pm.Model() as m: | ||
logsigma = pm.Uniform("logsigma", -100, 100) | ||
mu = pm.Uniform("mu", -10000, 10000) | ||
yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) | ||
idata, posterior = pm.quadratic_approximation([mu, logsigma]) | ||
|
||
# BDA 3 sec. 4.1 - analytical solution | ||
bda_map = [y.mean(), np.log(y.std())] | ||
bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]]) | ||
|
||
assert np.allclose(posterior.mean, bda_map) | ||
assert np.allclose(posterior.cov, bda_cov, atol=1e-4) | ||
|
||
def test_hdi_contains_parameters_in_linear_regression(): | ||
N = 100 | ||
M = 2 | ||
sigma = 0.2 | ||
X = np.random.randn(N, M) | ||
A = np.random.randn(M) | ||
noise = sigma * np.random.randn(N) | ||
y = X @ A + noise | ||
|
||
with pm.Model() as lm: | ||
weights = pm.Normal("weights", mu=0, sigma=1, shape=M) | ||
noise = pm.Exponential("noise", lam=1) | ||
y_observed = pm.Normal("y_observed", mu=X @ weights, sigma=noise, observed=y) | ||
|
||
idata, _ = pm.quadratic_approximation([weights, noise]) | ||
|
||
hdi = az.hdi(idata) | ||
weight_hdi = hdi.weights.values | ||
assert np.all(np.bitwise_and(weight_hdi[0, :] < A, A < weight_hdi[1, :])) | ||
assert hdi.noise.values[0] < sigma < hdi.noise.values[1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need this kwarg and should fix it to 1.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately, Arviz will complain if there aren't at least 2 chains.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you share an example? I think you'll get trouble for multivariate distributions, where arviz will interpret the leading dimension as the chains, but I would solve that as follows:
But maybe I'm misunderstanding the issue!
I'm against including chains since it gives the impression that you might want to run diagnostics on the quality of the returned samples.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess that's not too bad. All the plots seem to work. I just saw that warning and figured nothing would work. I'll remove the chains. I fully agree. The less impression we can give that running diagnostics the better. A little warning and a NaN might actually be good :)